@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:

 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None

+    # This is just to make new AITER MLA API work
+    # -- MTP support is not added yet.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata \
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata \
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]

     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -210,6 +217,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)

         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
@@ -228,6 +236,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
@@ -247,16 +257,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                                                    1,
                                                    device=device,
                                                    dtype=torch.int)
+
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None

         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr

         return metadata

@@ -265,21 +281,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):

     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
-            max_batch_size=max_batch_size,
-            block_size=self.runner.block_size,
-            max_block_per_batch=self.runner.get_max_block_per_batch(),
-            device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
+            get_aiter_mla_metadata(
+                max_batch_size=max_batch_size,
+                block_size=self.runner.block_size,
+                max_block_per_batch=\
+                    self.runner.get_max_block_per_batch(),
+                device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr

         with super().graph_capture(max_batch_size):
             yield

         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor

     def graph_capture_get_metadata_for_batch(
         self,
@@ -293,10 +313,12 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr

         return metadata

@@ -313,6 +335,7 @@ def get_graph_input_buffers(self,
         input_buffers[
             "paged_kv_last_page_lens"] = attn_metadata.\
                 decode_metadata.paged_kv_last_page_lens
+        input_buffers['qo_indptr'] = attn_metadata.qo_indptr

         return input_buffers

@@ -332,6 +355,8 @@ def prepare_graph_input_buffers(self,
         input_buffers["paged_kv_last_page_lens"].copy_(
             attn_metadata.decode_metadata.paged_kv_last_page_lens,
             non_blocking=True)
+        input_buffers["qo_indptr"].copy_(
+            attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -372,11 +397,9 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            q,
+            k,
+            v,
             **kwargs,
         )

@@ -396,7 +419,7 @@ def _forward_decode(
         B = q_nope.shape[0]

         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
@@ -405,6 +428,8 @@ def _forward_decode(
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)
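
For readers unfamiliar with the indptr convention, below is a minimal, self-contained sketch of what the new qo_indptr metadata holds. It is not part of the diff, and the helper name build_qo_indptr is purely illustrative. Since MTP is not supported yet, each decode sequence contributes exactly one query token, so the pointer is just the cumulative count 0..batch_size, mirroring self.qo_indptr.append(self.qo_indptr[-1] + 1) in _update_paged_kv_tensors above.

import torch

def build_qo_indptr(batch_size: int, device: str = "cpu") -> torch.Tensor:
    # Illustrative only: one query token per sequence, accumulated the same
    # way the metadata builder does it (start at 0, add 1 per sequence).
    qo_indptr = [0]
    for _ in range(batch_size):
        qo_indptr.append(qo_indptr[-1] + 1)
    return torch.tensor(qo_indptr, device=device, dtype=torch.int)

# build_qo_indptr(4) -> tensor([0, 1, 2, 3, 4], dtype=torch.int32)
# Queries for sequence i occupy the half-open range
# [qo_indptr[i], qo_indptr[i + 1]) of the flattened query tensor, which is
# why the CUDA-graph padding above repeats the last value: padded slots
# end up with empty query ranges.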