@@ -118,7 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
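With the annotation tightened to `Optional[list[int]]`, consumers of `AscendMetadata.seq_lens_list` must now handle `None`. A minimal sketch of the guard this implies (the helper and its fallback are illustrative, not part of the patch):

```python
from typing import Optional

import torch


def kv_seq_lens(seq_lens_list: Optional[list[int]],
                seq_lens: torch.Tensor) -> list[int]:
    # Hypothetical helper: fall back to the tensor form when the
    # per-request Python list is absent (e.g. while FIA is inactive).
    return seq_lens_list if seq_lens_list is not None else seq_lens.tolist()
```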
@@ -168,8 +168,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two params to common metadata in runners,
         # preparing for the hybrid KV groups feature
         query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        # Since FIA for GQA is not active for now, temporarily drop the
+        # fallback to the runner's seq_lens_list
+        seq_lens_list = common_attn_metadata.seq_lens_list

         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
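Keeping the explicit `is not None` check for `query_lens` (rather than shortening it to `or`) matters because the value can be a `torch.Tensor`: `or` would invoke `Tensor.__bool__`, which raises for tensors with more than one element. A quick illustration:

```python
import torch

lens = torch.tensor([3, 5])        # multi-element tensor
fallback = torch.tensor([1, 1])

# `lens or fallback` would raise:
#   RuntimeError: Boolean value of Tensor with more than one element is ambiguous
chosen = lens if lens is not None else fallback
print(chosen)  # tensor([3, 5])
```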
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
-            # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            # NOTE: We only need to pay attention to seq_lens and block_table here
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))

             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
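The dummy path now fills a tensor shaped like the runner's `seq_lens_cpu` instead of building a Python list, so the placeholder keeps the shape and dtype the kernels expect. A standalone sketch of the construction (the stand-in tensor's shape and dtype are assumptions; the real one lives on the runner):

```python
import torch

# Stand-in for self.runner.seq_lens_cpu with an assumed shape/dtype.
seq_lens_cpu = torch.zeros(4, dtype=torch.int32)

# empty_like + fill_ preserves shape, dtype and device, unlike the
# old `seq_lens_list=[2] * num_reqs` Python list.
dummy_seq_lens = torch.empty_like(seq_lens_cpu).fill_(2)
print(dummy_seq_lens)  # tensor([2, 2, 2, 2], dtype=torch.int32)
```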
@@ -349,82 +350,42 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if self.full_graph:
-                graph_params = get_graph_params()
-                q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                v = self.value_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                actual_seq_lens = attn_metadata.seq_lens_list
-                attn_args = {
-                    "query": q,
-                    "key": k,
-                    "value": v,
-                    "actual_seq_lengths_kv": actual_seq_lens,
-                    "block_table": attn_metadata.block_tables,
-                    "num_heads": self.num_heads,
-                    "scale": self.scale,
-                    "input_layout": "BSH",
-                    "num_key_value_heads": self.num_kv_heads,
-                    "block_size": self.block_size,
-                }
-
-                # Prepare tensors for attention output
-                # TODO: Refactor this to step-level instead of layer-level
-                attn_output = torch.empty(num_tokens,
-                                          1,
-                                          self.hidden_size,
-                                          dtype=output.dtype,
-                                          device=output.device)
-                softmax_lse = torch.empty(num_tokens,
-                                          dtype=output.dtype,
-                                          device=output.device)
-
-                # Get workspace from cache or calculate it if not present.
-                workspace = graph_params.workspaces.get(num_tokens)
-                if workspace is None:
-                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        **attn_args)
-                    graph_params.workspaces[num_tokens] = workspace
-
-                forward_context = get_forward_context()
-                if not forward_context.capturing:
-                    # Execute attention kernel directly in non-capturing mode
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                else:
-                    # Handle graph capturing mode
-                    stream = torch_npu.npu.current_stream()
-
-                    event = torch.npu.ExternalEvent()
-                    event.wait(stream)
-                    event.reset(stream)
-                    graph_params.events[num_tokens].append(event)
-
-                    graph_params.attn_params[num_tokens].append(
-                        (q, k, v, actual_seq_lens,
-                         attn_metadata.block_tables, self.num_heads,
-                         self.scale, self.num_kv_heads, attn_output,
-                         softmax_lse))
-
-                    torch.npu.graph_task_group_begin(stream)
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                    handle = torch.npu.graph_task_group_end(stream)
-                    graph_params.handles[num_tokens].append(handle)
-
-                # Reshape output to match the expected format
-                output.copy_(
-                    attn_output.view(num_tokens, self.num_heads,
-                                     self.head_size))
+            graph_params = get_graph_params()
+
+            forward_context = get_forward_context()
+            if not forward_context.capturing:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
             else:
+                # Handle graph capturing mode
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
                     out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
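For context on the capture branch: the tuple, event, and handle stashed per `num_tokens` are consumed at replay time, when the runner re-issues `_npu_paged_attention` with fresh `seq_lens`, patches the captured task group through its handle, and records the `ExternalEvent` the graph is waiting on. A rough sketch of that consumer, assuming `graph_task_update_begin`/`graph_task_update_end` as the update-side counterparts of the capture calls above and an update stream owned by the runner:

```python
import torch
import torch_npu


def update_decode_attn_params(graph_params, num_tokens: int,
                              seq_lens: torch.Tensor,
                              update_stream) -> None:
    # Sketch only: the tuple layout mirrors what the capture branch
    # appends to graph_params.attn_params; the update API names are
    # assumed counterparts of graph_task_group_begin/end.
    for (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
         block_tables, _, output), handle, event in zip(
             graph_params.attn_params[num_tokens],
             graph_params.handles[num_tokens],
             graph_params.events[num_tokens]):
        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                num_kv_heads=num_kv_heads,
                num_heads=num_heads,
                scale_value=scale,
                block_table=block_tables,
                context_lens=seq_lens,
                out=output)
            torch.npu.graph_task_update_end(update_stream)
        # Unblocks the event.wait(stream) recorded during capture.
        event.record(update_stream)
```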