@@ -167,161 +167,6 @@ def __init__(self,
     def lambda_init_fn(self, depth):
         return 0.8 - 0.6 * math.exp(-0.3 * depth)
 
-
-    def split_heads(self, x):
-        # split by num_heads, the stripe pattern is friendly to tensor parallel.
-        x = rearrange(x, "... (H two) D -> ... H two D", two=2)
-        x1 = x[..., 0, :]
-        x2 = x[..., 1, :]
-        return x1.contiguous(), x2.contiguous()
-
-    def split_kv_cache(self, x):
-        # split by num_heads, the stripe pattern is friendly to tensor parallel.
-        if x.numel() == 0:
-            return torch.empty(0), torch.empty(0)
-
-        x1, x2 = x[0], x[1]
-        return x1, x2
-
-    def forward_decode(
-        self,
-        query: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-    ):
-        if not attn_metadata.decode_metadata:
-            block_tables_arg = attn_metadata.cross_layer_shared_block_tables
-        else:
-            block_tables_arg = attn_metadata.block_tables
-
-        output = flash_attn_with_kvcache(
-            q=query.unsqueeze(1),
-            k_cache=k_cache,
-            v_cache=v_cache,
-            block_table=block_tables_arg,
-            cache_seqlens=attn_metadata.seq_lens_tensor,
-            softmax_scale=self.attn.impl.scale,
-            causal=True,
-            window_size=self.attn.impl.sliding_window,
-            alibi_slopes=self.attn.impl.alibi_slopes,
-            softcap=self.attn.impl.logits_soft_cap,
-        ).squeeze(1)
-        return output
-
-    def populate_kv_cache(self,
-                          key,
-                          value,
-                          kv_cache,
-                          attn_metadata):
-        if (kv_cache.numel() > 0):
-            if (key is not None) and (value is not None):
-                updated_slot_mapping = attn_metadata.slot_mapping
-                # previous_key_cache_sum = key_cache.sum()
-                # previous_value_cache_sum = value_cache.sum()
-
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    kv_cache[0],
-                    kv_cache[1],
-                    updated_slot_mapping.flatten(),
-                    self.attn.impl.kv_cache_dtype,
-                    self._k_scale,
-                    self._v_scale,
-                )
-                # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch"
-                # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch"
-                # if key_cache.sum() - previous_key_cache_sum != key.sum():
-                #     print("key_cache sum mismatch")
-                # if value_cache.sum() - previous_value_cache_sum != value.sum():
-                #     print("value_cache sum mismatch")
-
-    def forward_customized(
-        self,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor],
-        value: Optional[torch.Tensor],
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata
-    ) -> torch.Tensor:
-
-        head_size = self.head_dim
-        num_heads = self.num_heads // 2
-        num_kv_heads = self.num_key_value_heads // 2
-
-        query = query.view(-1, num_heads, head_size)
-        if key is not None:
-            assert value is not None
-            key = key.view(-1, num_kv_heads, head_size)
-            value = value.view(-1, num_kv_heads, head_size)
-        else:
-            assert value is None
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
-        assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch"
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
-                # normal attention
-                prefill_output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=prefill_meta.seq_start_loc,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
-                    softmax_scale=self.attn.impl.scale,
-                    causal=True,
-                    window_size=self.attn.impl.sliding_window,
-                    alibi_slopes=self.attn.impl.alibi_slopes,
-                    softcap=self.attn.impl.logits_soft_cap,
-                )
-                assert prefill_output.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = prefill_output
-            else:
-                raise Exception("prefix caching not supported")
-
-        if decode_meta := attn_metadata.decode_metadata:
-            block_tables_arg = decode_meta.block_tables
-            try:
-                output[num_prefill_tokens:] = flash_attn_with_kvcache(
-                    q=decode_query.unsqueeze(1),
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    block_table=block_tables_arg,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
-                    softmax_scale=self.attn.impl.scale,
-                    causal=True,
-                    window_size=self.attn.impl.sliding_window,
-                    alibi_slopes=self.attn.impl.alibi_slopes,
-                    softcap=self.attn.impl.logits_soft_cap,
-                ).squeeze(1)
-            except Exception as e:
-                logger.error(
-                    f"Error in PagedAttention.forward_decode: {str(e)}")
-                raise e
-
-        # Reshape the output tensor.
-        return output.view(-1, num_heads, head_size)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -333,86 +178,9 @@ def forward(
         if not self.yoco_cross:  # need to generate kv-cache
             qkv = self.Wqkv(hidden_states)
             q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1)
-            reference_attn_output = self.attn(q, k, v)
-            # # q, k = self.rotary_emb(positions, q, k)
-            # # reshape
-            # q = q.view(-1, self.num_heads, self.head_dim)
-            # k = k.view(-1, self.num_key_value_heads, self.head_dim)
-            # v = v.view(-1, self.num_key_value_heads, self.head_dim)
-
-            # q1, q2 = self.split_heads(q)
-            # k1, k2 = self.split_heads(k)
-            # v1, v2 = self.split_heads(v)
-
-            # # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size)
-            # # Split by half along the first dimension.
-            # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
-            # assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
-            # assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
-
-            # if kv_cache1.numel() != 0:
-            #     self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata)
-            #     self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata)
-
-            #     key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
-            #     key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
-            # else:
-            #     key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
-            #     key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
-            # attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata)
-            # attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata)
-            # attn11 = attn11.view(q1.shape)
-            # attn12 = attn12.view(q1.shape)
-            # attn1 = torch.cat([attn11, attn12], dim=-1)
-
-            # attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata)
-            # attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata)
-            # attn21 = attn21.view(q2.shape)
-            # attn22 = attn22.view(q2.shape)
-            # attn2 = torch.cat([attn21, attn22], dim=-1)
-
-            # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
-            # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
-            # lambda_full = lambda_1 - lambda_2 + self.lambda_init
-
-            # attn = attn1 - lambda_full * attn2
-            # # attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
-            # attn = self.subln(attn)
-            # attn = attn * (1 - self.lambda_init)
-            # # reshape back to 2 * num_head
-            # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
             attn_output = self.attn(q, k, v)
         else:  # re-use the kv cache, full attention
             q = self.Wqkv(hidden_states)
-            # q = q.view(-1, self.num_heads, self.head_dim)
-            # q1, q2 = self.split_heads(q)
-            # # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size)
-            # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
-            # key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
-            # key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
-
-            # attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata)
-            # attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata)
-            # attn11 = attn11.view(q1.shape)
-            # attn12 = attn12.view(q1.shape)
-            # attn1 = torch.cat([attn11, attn12], dim=-1)
-
-            # attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata)
-            # attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata)
-            # attn21 = attn21.view(q2.shape)
-            # attn22 = attn22.view(q2.shape)
-            # attn2 = torch.cat([attn21, attn22], dim=-1)
-
-            # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
-            # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
-            # lambda_full = lambda_1 - lambda_2 + self.lambda_init
-            # attn = attn1 - lambda_full * attn2
-            # attn = self.subln(attn)
-            # attn = attn * (1 - self.lambda_init)
-            # # reshape back to 2 * num_head
-            # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
-
-
             if self.attn.kv_cache[0].numel() == 0:
                 self.attn.kv_cache = [kv_cache]
             attn_output = self.attn(q, None, None)
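
For reference, below is a minimal, self-contained sketch of the differential-attention recombination that the removed commented-out code describes: compute lambda_full from the lambda_q/lambda_k parameters, take attn1 - lambda_full * attn2, normalize with subln, and scale by (1 - lambda_init). The standalone function names, the toy shapes, and the use of nn.RMSNorm standing in for self.subln are assumptions for illustration only; they are not part of this diff.

```python
import math
import torch
import torch.nn as nn
from einops import rearrange


def lambda_init_fn(depth: int) -> float:
    # Same schedule as lambda_init_fn in the diff above.
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


def differential_combine(attn1: torch.Tensor,
                         attn2: torch.Tensor,
                         lambda_q1: torch.Tensor, lambda_k1: torch.Tensor,
                         lambda_q2: torch.Tensor, lambda_k2: torch.Tensor,
                         lambda_init: float,
                         subln: nn.Module) -> torch.Tensor:
    # attn1/attn2: (num_tokens, num_heads // 2, 2 * head_dim), i.e. each is the
    # concatenation of two half-head attention outputs, as in the commented code.
    lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(attn1)
    lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(attn1)
    lambda_full = lambda_1 - lambda_2 + lambda_init
    attn = attn1 - lambda_full * attn2
    attn = subln(attn)
    attn = attn * (1 - lambda_init)
    # Reshape back to the full head count: (num_tokens, num_heads, head_dim).
    return rearrange(attn, "... H (two D) -> ... (H two) D", two=2)


if __name__ == "__main__":
    # Toy sizes, not tied to any real checkpoint.
    num_tokens, num_heads, head_dim, depth = 4, 8, 16, 3
    attn1 = torch.randn(num_tokens, num_heads // 2, 2 * head_dim)
    attn2 = torch.randn_like(attn1)
    lambdas = [torch.randn(head_dim) for _ in range(4)]
    # nn.RMSNorm (PyTorch >= 2.4) is assumed here as a stand-in for self.subln.
    subln = nn.RMSNorm(2 * head_dim, elementwise_affine=True)
    out = differential_combine(attn1, attn2, *lambdas, lambda_init_fn(depth), subln)
    print(out.shape)  # torch.Size([4, 8, 16])
```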