from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
+ from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

+ from vllm_ascend.attention.utils import \
+     AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill
+ from vllm_ascend.utils import get_graph_params


class AscendAttentionBackend(AttentionBackend):
@@ -114,6 +118,7 @@ class AscendMetadata:
    query_start_loc: torch.Tensor
    query_lens: torch.Tensor
    seq_lens: torch.Tensor
+   seq_lens_list: list
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None
    # (num_tokens,). The indices of the token slots that input tokens will be
@@ -149,37 +154,69 @@ def build(self,
              num_reqs,
              num_actual_tokens,
              max_query_len,
-             common_prefix_len,
-             enable_dbo_across_dp: bool = False):
+             common_attn_metadata: CommonAttentionMetadata,
+             enable_dbo_across_dp: bool = False,
+             *args,
+             **kwargs):

        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
        )
        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
            block_table[:num_reqs])

-       query_lens = self.runner.query_lens
-       seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-       slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-           self.runner.device, non_blocking=True)
+       query_start_loc = common_attn_metadata.query_start_loc
+       seq_lens = common_attn_metadata.seq_lens
+       # TODO: Refactor these two params to common metadata in runners,
+       # preparing for the hybrid KV groups feature
+       query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+       seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+       slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
        attn_mask = self.runner.attn_mask
        attn_state = self.runner.attn_state
-       query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-       query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                non_blocking=True)

        attn_metadata = AscendMetadata(
            num_actual_tokens=num_actual_tokens,
            block_tables=block_table,
            query_start_loc=query_start_loc,
            query_lens=query_lens,
            seq_lens=seq_lens,
+           seq_lens_list=seq_lens_list,
            max_query_len=max_query_len,
            slot_mapping=slot_mapping,
            attn_mask=attn_mask,
            attn_state=attn_state,
            enable_dbo_across_dp=enable_dbo_across_dp)
        return attn_metadata

+   def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                            num_scheduled_tokens, attn_state):
+       if attn_state == AscendAttentionState.DecodeOnly:
+           # NOTE: We only need to pay attention to seq_lens_list and block_table here
+           common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                          num_reqs)
+
+           block_table = self.runner.input_batch.block_table[0].block_table
+           block_table[:num_reqs, 0] = torch.arange(1,
+                                                    num_reqs + 1,
+                                                    device=block_table.device,
+                                                    dtype=block_table.dtype)
+
+           attn_metadata = self.build(
+               num_reqs=num_reqs,
+               num_actual_tokens=num_actual_tokens,
+               max_query_len=num_scheduled_tokens.max(),
+               common_prefix_len=0,
+               common_attn_metadata=common_attn_metadata,
+           )
+       else:
+           raise NotImplementedError(
+               "Currently we only support building dummy metadata for DecodeOnly state"
+           )
+
+       attn_metadata.attn_state = attn_state
+       return attn_metadata
+

class AscendAttentionBackendImpl(AttentionImpl):

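For reference, here is a minimal sketch of the AscendCommonAttentionMetadata container that the new build() signature and build_dummy_metadata() assume. The field names are inferred from how build() consumes the object; the real definition lives in vllm_ascend/attention/utils.py and may differ, and the Optional defaults are an assumption made so the dummy path can supply only seq_lens_list:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendCommonAttentionMetadata:
    # Cumulative query offsets and per-request lengths; left as None by the
    # dummy-metadata path, which only cares about seq_lens_list and block_table.
    query_start_loc: Optional[torch.Tensor] = None
    seq_lens: Optional[torch.Tensor] = None
    query_lens: Optional[torch.Tensor] = None
    seq_lens_list: Optional[list] = None


# Usage mirroring build_dummy_metadata() above, e.g. for num_reqs == 4:
dummy = AscendCommonAttentionMetadata(seq_lens_list=[2] * 4)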
@@ -217,6 +254,10 @@ def __init__(
        self.key_cache = None
        self.value_cache = None

+       vllm_config = get_current_vllm_config()
+       self.full_graph = vllm_config.compilation_config.full_cuda_graph
+       self.block_size = vllm_config.cache_config.block_size
+

    def forward(
        self,
        layer: AttentionLayer,
@@ -228,21 +269,7 @@ def forward(
        output: Optional[torch.Tensor] = None,
        trace_flag: bool = True,
    ) -> torch.Tensor:
-       """Forward pass with Ascend attention.
-       Args:
-           query: shape = [batch_size, seq_len, num_heads * head_size]
-           key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-           value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-           kv_cache: shape = [2, num_blocks, block_size,
-                              num_kv_heads, head_size]
-                     key_cache = [num_blocks, block_size,
-                                  num_kv_heads, head_size]
-                     value_cache = [num_blocks, block_size,
-                                    num_kv_heads, head_size]
-           attn_metadata: Metadata for attention.
-       Returns:
-           shape = [batch_size * seq_len, num_heads, head_size]
-       """
+       """Forward pass with Ascend attention."""
        num_tokens = query.shape[0]
        if output is None:
            output = torch.empty(num_tokens,
@@ -322,16 +349,92 @@ def forward(
                    scale_value=self.scale,
                    out=output)
            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-               torch_npu._npu_paged_attention(
-                   query=query,
-                   key_cache=self.key_cache,
-                   value_cache=self.value_cache,
-                   num_kv_heads=self.num_kv_heads,
-                   num_heads=self.num_heads,
-                   scale_value=self.scale,
-                   block_table=attn_metadata.block_tables,
-                   context_lens=attn_metadata.seq_lens,
-                   out=output)
+               if self.full_graph:
+                   graph_params = get_graph_params()
+                   q = query.view(num_tokens, -1, self.hidden_size)
+                   k = self.key_cache.view(  # type: ignore
+                       -1, self.block_size,
+                       self.num_kv_heads * self.head_size)
+                   v = self.value_cache.view(  # type: ignore
+                       -1, self.block_size,
+                       self.num_kv_heads * self.head_size)
+                   actual_seq_lens = attn_metadata.seq_lens_list
+                   attn_args = {
+                       "query": q,
+                       "key": k,
+                       "value": v,
+                       "actual_seq_lengths_kv": actual_seq_lens,
+                       "block_table": attn_metadata.block_tables,
+                       "num_heads": self.num_heads,
+                       "scale": self.scale,
+                       "input_layout": "BSH",
+                       "num_key_value_heads": self.num_kv_heads,
+                       "block_size": self.block_size,
+                   }
+
+                   # Prepare tensors for attention output
+                   # TODO: Refactor this to step-level instead of layer-level
+                   attn_output = torch.empty(num_tokens,
+                                             1,
+                                             self.hidden_size,
+                                             dtype=output.dtype,
+                                             device=output.device)
+                   softmax_lse = torch.empty(num_tokens,
+                                             dtype=output.dtype,
+                                             device=output.device)
+
+                   # Get workspace from cache or calculate it if not present.
+                   workspace = graph_params.workspaces.get(num_tokens)
+                   if workspace is None:
+                       workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                           **attn_args)
+                       graph_params.workspaces[num_tokens] = workspace
+
+                   forward_context = get_forward_context()
+                   if not forward_context.capturing:
+                       # Execute attention kernel directly in non-capturing mode
+                       torch.ops.npu.npu_fused_infer_attention_score.out(
+                           workspace=workspace,
+                           out=[attn_output, softmax_lse],
+                           **attn_args)
+                   else:
+                       # Handle graph capturing mode
+                       stream = torch_npu.npu.current_stream()
+
+                       event = torch.npu.ExternalEvent()
+                       event.wait(stream)
+                       event.reset(stream)
+                       graph_params.events[num_tokens].append(event)
+
+                       graph_params.attn_params[num_tokens].append(
+                           (q, k, v, actual_seq_lens,
+                            attn_metadata.block_tables, self.num_heads,
+                            self.scale, self.num_kv_heads, attn_output,
+                            softmax_lse))
+
+                       torch.npu.graph_task_group_begin(stream)
+                       torch.ops.npu.npu_fused_infer_attention_score.out(
+                           workspace=workspace,
+                           out=[attn_output, softmax_lse],
+                           **attn_args)
+                       handle = torch.npu.graph_task_group_end(stream)
+                       graph_params.handles[num_tokens].append(handle)
+
+                   # Reshape output to match the expected format
+                   output.copy_(
+                       attn_output.view(num_tokens, self.num_heads,
+                                        self.head_size))
+               else:
+                   torch_npu._npu_paged_attention(
+                       query=query,
+                       key_cache=self.key_cache,
+                       value_cache=self.value_cache,
+                       num_kv_heads=self.num_kv_heads,
+                       num_heads=self.num_heads,
+                       scale_value=self.scale,
+                       block_table=attn_metadata.block_tables,
+                       context_lens=attn_metadata.seq_lens,
+                       out=output)
            # Normal V1 situation.
            else:
                # use chunked prefill for head size 192 scenario, like deepseek
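The capture path above relies on get_graph_params() providing per-batch-size registries for workspaces, events, saved kernel arguments, and task-group handles. Below is a hypothetical sketch of such a registry, with member names inferred from the call sites in this hunk; the actual vllm_ascend.utils implementation may differ:

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class GraphParams:
    # Fused-attention workspace tensors cached per padded token count.
    workspaces: Dict[int, Any] = field(default_factory=dict)
    # External events recorded so captured kernels can be re-triggered on replay.
    events: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))
    # Saved kernel arguments so they can be updated before graph replay.
    attn_params: Dict[int, List[tuple]] = field(default_factory=lambda: defaultdict(list))
    # Handles returned by torch.npu.graph_task_group_end().
    handles: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))


_GRAPH_PARAMS = GraphParams()


def get_graph_params() -> GraphParams:
    return _GRAPH_PARAMS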