[Feature] support c16 prefix_cache in flash_attention_v3 #2766

Status: Open. Merges 2 commits into base branch develop.

Conversation

lizhenyun01

support c16 prefix_cache in flash_attention_v3

paddle-bot bot commented Jul 9, 2025

Thanks for your contribution!

Jiang-Jia-Jun requested a review from Copilot on July 9, 2025, 04:20

Copilot AI left a comment

Pull Request Overview

This PR adds support for 16-bit (c16) prefix cache in flash_attention_v3 by introducing a new dequantization kernel and updating the dispatch logic.

  • Added append_dequant_cache_kv_c16 kernel for c16 (no-quant) cache paths.
  • Extended AppendDequantCache to launch the c16 kernel when cache_quant_type == "none".
  • Removed an unused include for flash_attn_v3_kernel.h.
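
The dispatch described in the bullets above can be pictured with a small host-side sketch. The names `CachePath` and `SelectCachePath` are hypothetical and do not appear in the PR; the real `AppendDequantCache` signature is not shown in this conversation, and treating every value other than "none" as a quantized path is an assumption.

```cpp
#include <string>

// Hypothetical enum for illustration only; the PR does not define it.
enum class CachePath {
  kC16,        // unquantized 16-bit cache, handled by the new c16 kernel
  kQuantized,  // any other cache_quant_type (assumed to use existing kernels)
};

// Hypothetical helper mirroring the check described in the overview:
// cache_quant_type == "none" selects the c16 path.
inline CachePath SelectCachePath(const std::string& cache_quant_type) {
  return cache_quant_type == "none" ? CachePath::kC16 : CachePath::kQuantized;
}
```

Under these assumptions, `SelectCachePath("none")` picks the c16 path and any other string falls through to the quantized path.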

const uint32_t batch_id = batch_ids[tile_idx];
const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
if (seq_lens_this_time <= 0) {

Copilot AI Jul 9, 2025

The condition compares the pointer seq_lens_this_time itself rather than the value it points to. You likely meant to check an element of the sequence-length array, e.g., seq_lens_this_time[tile_idx] <= 0 or seq_lens_this_time[batch_id] <= 0.

Suggested change
- if (seq_lens_this_time <= 0) {
+ if (seq_lens_this_time[batch_id] <= 0) {

Copilot uses AI. Check for mistakes.
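
For illustration, the guard this suggestion aims at might look like the sketch below, written as plain host-side C++ rather than kernel code. `HasTokensThisStep` is a hypothetical helper that does not appear in the PR, and indexing by `batch_id` (rather than `tile_idx`) is taken from the suggested change, not confirmed by the author. What the kernel does inside the guard is not shown in this conversation, so the sketch only covers the check itself.

```cpp
#include <cstdint>

// Hypothetical helper, not from the PR: reads the per-batch length
// through the pointer instead of testing the pointer itself.
inline bool HasTokensThisStep(const int* seq_lens_this_time,
                              std::uint32_t batch_id) {
  // Index first, then compare the element's value against zero.
  return seq_lens_this_time[batch_id] > 0;
}
```

A kernel would presumably take the early-exit branch when this returns false for the tile's batch.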

Comment on lines 253 to 254
k_tile_ptr0[8] = frag_dq_T[6];
k_tile_ptr0[9] = frag_dq_T[7];

Copilot AI Jul 9, 2025

These writes use k_tile_ptr0 instead of k_tile_ptr1 for the second half of the fragment. They should be k_tile_ptr1[8] = frag_dq_T[6]; and k_tile_ptr1[9] = frag_dq_T[7];.

Suggested change
- k_tile_ptr0[8] = frag_dq_T[6];
- k_tile_ptr0[9] = frag_dq_T[7];
+ k_tile_ptr1[8] = frag_dq_T[6];
+ k_tile_ptr1[9] = frag_dq_T[7];
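
A minimal host-side sketch of the store pattern this suggestion implies is given below. Only the last two assignments come from the review; the other six, the function name `StoreFragmentRows`, and the idea that the eight fragment values split into two column pairs on each of two rows are assumptions made to give the example a complete shape. The actual kernel may lay out the fragment differently.

```cpp
// Hypothetical helper, not from the PR. row0 and row1 stand in for
// k_tile_ptr0 and k_tile_ptr1; each must have room for index 9.
template <typename T>
void StoreFragmentRows(const T (&frag)[8], T* row0, T* row1) {
  // Assumed layout for the first six values (not shown in the review).
  row0[0] = frag[0];
  row0[1] = frag[1];
  row1[0] = frag[2];
  row1[1] = frag[3];
  row0[8] = frag[4];
  row0[9] = frag[5];
  // From the suggested change: the last pair goes to the second row.
  // Writing it through row0 would overwrite row0[8] and row0[9]
  // and leave row1[8] and row1[9] unset.
  row1[8] = frag[6];
  row1[9] = frag[7];
}
```

The same reasoning applies to the V path, with `v_tile_ptr0` and `v_tile_ptr1` in place of the K pointers.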

Comment on lines 314 to 315
v_tile_ptr0[8] = frag_dq_T[6];
v_tile_ptr0[9] = frag_dq_T[7];

Copilot AI Jul 9, 2025

Similar to the K path, these writes use v_tile_ptr0 instead of v_tile_ptr1 for the second row. They should be v_tile_ptr1[8] = frag_dq_T[6]; and v_tile_ptr1[9] = frag_dq_T[7];.

Suggested change
- v_tile_ptr0[8] = frag_dq_T[6];
- v_tile_ptr0[9] = frag_dq_T[7];
+ v_tile_ptr1[8] = frag_dq_T[6];
+ v_tile_ptr1[9] = frag_dq_T[7];
