# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+from typing import Any, ClassVar, Optional

import torch

    chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
+from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-
-if TYPE_CHECKING:
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)
@@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):
    full_cudagraph_supported: ClassVar[bool] = True

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        self.runner = runner
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        self.device = device
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
+
+        model_config = vllm_config.model_config
+        self.num_heads_q = model_config.get_num_attention_heads(
+            vllm_config.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
+            vllm_config.parallel_config)
+        self.headdim = model_config.get_head_size()
+
+        self.attention_chunk_size = getattr(vllm_config.scheduler_config,
+                                            'attention_chunk_size', None)

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata
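The builder now pulls everything static (head counts, head size, attention chunk size) from VllmConfig at construction time instead of reaching into the GPUModelRunner. A rough usage sketch of where those values come from; EngineArgs.create_engine_config() and the model name are illustrative assumptions, not part of this diff:

    # Illustrative sketch only: the values the new constructor caches.
    from vllm.engine.arg_utils import EngineArgs

    vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    model_config = vllm_config.model_config

    num_heads_q = model_config.get_num_attention_heads(vllm_config.parallel_config)
    num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
    headdim = model_config.get_head_size()
    # Falls back to None when the scheduler config has no chunked local attention.
    attention_chunk_size = getattr(vllm_config.scheduler_config,
                                   "attention_chunk_size", None)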
@@ -96,42 +102,32 @@ def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> TritonAttentionMetadata:
-        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
-        block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping

        # for local attention
        local_attn_metadata = None
-        if self.runner.attention_chunk_size is not None:
+        if self.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    self.attention_chunk_size,
+                    common_attn_metadata.query_start_loc_cpu.numpy(),
+                    common_attn_metadata.seq_lens_cpu.numpy(),
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.runner.device, non_blocking=True)
+                self.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+                self.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max().item()
+            local_max_seq_len = virt_k_seqlens_np.max().item()

            local_attn_metadata = TritonAttentionMetadata \
                .LocalAttentionMetadata(
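Note that the removed slot_mapping preparation (copy from the CPU buffer, then fill the unused tail with -1 so reshape_and_cache skips those slots under full CUDA graph capture) is not replaced inside build(); presumably it now happens wherever CommonAttentionMetadata is assembled. A minimal sketch of that invariant, with made-up sizes:

    # Minimal sketch (made-up sizes): the slot_mapping handed to the builder is
    # expected to already have its unused tail set to -1 for full CUDA graph mode.
    import torch

    num_actual_tokens, padded_tokens = 5, 8
    slot_mapping_cpu = torch.arange(padded_tokens, dtype=torch.int64)

    slot_mapping = torch.full((padded_tokens,), -1, dtype=torch.int64, device="cuda")
    slot_mapping[:num_actual_tokens].copy_(slot_mapping_cpu[:num_actual_tokens],
                                           non_blocking=True)
    # Entries past num_actual_tokens stay -1, so reshape_and_cache ignores them.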
@@ -148,14 +144,13 @@ def build(self,
        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
-                                                device=self.runner.device)
+                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
-                                          device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                                          device=self.device)
+            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
                              common_prefix_len)
-            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
-                self.runner.device)
+            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
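The suffix_kv_lens change follows from the sequence lengths now arriving as a CPU torch tensor (seq_lens_cpu) rather than a NumPy array, so the torch.from_numpy round-trip goes away. A tiny illustration with made-up values:

    # Tiny illustration (made-up values): subtract the shared prefix on CPU,
    # then move the result to the target device in one step.
    import torch

    seq_lens_cpu = torch.tensor([10, 12, 7], dtype=torch.int32)
    common_prefix_len = 4
    suffix_kv_lens = (seq_lens_cpu - common_prefix_len).to("cuda")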