Skip to content

[Feature] support c16 prefix_cache in flash_attention_v3 #2766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 211 additions & 15 deletions custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "encoder_write_cache_with_rope_impl.cuh"
#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
#include "paddle/phi/backends/context_pool.h"
#include "remote_cache_kv_ipc.h"

Expand Down Expand Up @@ -148,13 +147,188 @@ void gqa_rotary_qk_split_variable(
dim_head);
}

template <typename T,
typename CacheT,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS=4>
__global__ void append_cache_kv_c16(
const T *__restrict__ cache_k,
const T *__restrict__ cache_v,
T *__restrict__ k_out,
T *__restrict__ v_out,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ cu_seqlens_k,
const int *__restrict__ block_tables,
const int *batch_ids,
const int *tile_ids_per_batch,
const int max_blocks_per_seq,
const int kv_num_heads) {
// start_kv_idx: 每个block的起始kv_idx
// batch_id:每个block属于的batch
// TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8)
const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t tid = threadIdx.x, wid = threadIdx.y;

const uint32_t batch_id = batch_ids[tile_idx];
const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
if (seq_lens_this_time[batch_id] <= 0) {
return;
}

const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];
// cache_kv idx
uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
uint32_t block_stride = kv_num_heads * kv_h_stride;
const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;

// k_out v_out idx
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针
T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针

uint32_t kv_frag[4];
T *frag_dq_T = reinterpret_cast<T *>(kv_frag);

constexpr uint32_t num_vecs_per_head =
HEAD_DIM / num_elems_per_128b<CacheT>();
constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head;

extern __shared__ uint8_t smem[];
smem_t k_smem(smem);
uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>();

// load k_smem 行是64 列是128
for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行
for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy);
k_read_idx += 8 * num_elems_per_128b<CacheT>();
}
k_smem_offset_w =
k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8;
k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>();
}
commit_group();
wait_group<0>();
__syncthreads();

// deal k_smem 行是64 列是128
for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16)
uint32_t col_idx = fy * 16 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// 存储
/***
r0c0,r0c1, r0c8,r0c9
r8c0,r8c1, r8c8,r8c9
***/
T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;

if (row_idx < end_idx) {
k_tile_ptr0[0] = frag_dq_T[0];
k_tile_ptr0[1] = frag_dq_T[1];
k_tile_ptr0[8] = frag_dq_T[4];
k_tile_ptr0[9] = frag_dq_T[5];
}

if (row_idx + 8 < end_idx) {
k_tile_ptr1[0] = frag_dq_T[2];
k_tile_ptr1[1] = frag_dq_T[3];
k_tile_ptr1[8] = frag_dq_T[6];
k_tile_ptr1[9] = frag_dq_T[7];
}
k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
k_smem_offset_r, fy);
}
k_smem_offset_r =
k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16;
}
// ================v================
smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>();

// load v_smem 行是64 列是128
for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行
for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy);
v_read_idx += 8 * num_elems_per_128b<CacheT>();
}
v_smem_offset_w =
v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8;
v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>();
}
commit_group();
wait_group<0>();
__syncthreads();

// deal v_smem 行是64 列是128

for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16)
uint32_t col_idx = fy * 16 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
// 存储
/***
r0c0,r0c1, r0c8,r0c9
r8c0,r8c1, r8c8,r8c9
***/
T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride;

if (row_idx < end_idx) {
v_tile_ptr0[0] = frag_dq_T[0];
v_tile_ptr0[1] = frag_dq_T[1];
v_tile_ptr0[8] = frag_dq_T[4];
v_tile_ptr0[9] = frag_dq_T[5];
}

if (row_idx + 8 < end_idx) {
v_tile_ptr1[0] = frag_dq_T[2];
v_tile_ptr1[1] = frag_dq_T[3];
v_tile_ptr1[8] = frag_dq_T[6];
v_tile_ptr1[9] = frag_dq_T[7];
}
v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
v_smem_offset_r, fy);
}
v_smem_offset_r =
v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16;
}
}

template <typename T,
typename CacheT,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS=4,
bool IS_FP8=false>
__global__ void append_dequant_cache_kv_c8(
__global__ void append_cache_kv_c8(
const CacheT *__restrict__ cache_k,
const CacheT *__restrict__ cache_v,
T *__restrict__ k_out,
Expand Down Expand Up @@ -214,7 +388,7 @@ __global__ void append_dequant_cache_kv_c8(

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>();

Expand Down Expand Up @@ -328,7 +502,7 @@ __global__ void append_dequant_cache_kv_c8(
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;


convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T
#ifdef C8_DEBUG
if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
Expand All @@ -353,7 +527,7 @@ __global__ void append_dequant_cache_kv_c8(
}

template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
void AppendDequantCache(
void AppendCacheKV(
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::Tensor &cache_k_dequant_scales,
Expand All @@ -371,19 +545,41 @@ void AppendDequantCache(
paddle::Tensor *k_out,
paddle::Tensor *v_out,
const cudaStream_t& stream
) {
) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS);

constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS);
if (cache_quant_type == "none") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
auto kernel_func = append_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;

if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel_func,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
kernel_func<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_k.data<int>(),
block_tables.data<int>(),
cache_batch_ids.data<int>(),
cache_tile_ids_per_batch.data<int>(),
max_blocks_per_seq,
kv_num_heads
);
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
auto kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
if (cache_quant_type == "cache_fp8") {
kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel_func,
Expand Down Expand Up @@ -561,7 +757,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
}

if (token_num < kv_token_num) {
AppendDequantCache<data_t, 128, 64>(
AppendCacheKV<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
Expand Down