-
Notifications
You must be signed in to change notification settings - Fork 563
[Feature] support c16 prefix_cache in flash_attention_v3 #2766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,7 +16,6 @@ | |||||||||
#include "paddle/extension.h" | ||||||||||
#include "paddle/phi/core/memory/memcpy.h" | ||||||||||
#include "encoder_write_cache_with_rope_impl.cuh" | ||||||||||
#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h" | ||||||||||
#include "paddle/phi/backends/context_pool.h" | ||||||||||
#include "remote_cache_kv_ipc.h" | ||||||||||
|
||||||||||
|
@@ -148,6 +147,181 @@ void gqa_rotary_qk_split_variable( | |||||||||
dim_head); | ||||||||||
} | ||||||||||
|
||||||||||
template <typename T, | ||||||||||
typename CacheT, | ||||||||||
uint32_t HEAD_DIM, | ||||||||||
uint32_t BLOCK_SIZE, | ||||||||||
uint32_t NUM_WARPS=4> | ||||||||||
__global__ void append_dequant_cache_kv_c16( | ||||||||||
const T *__restrict__ cache_k, | ||||||||||
const T *__restrict__ cache_v, | ||||||||||
T *__restrict__ k_out, | ||||||||||
T *__restrict__ v_out, | ||||||||||
const int *__restrict__ seq_lens_this_time, | ||||||||||
const int *__restrict__ seq_lens_decoder, | ||||||||||
const int *__restrict__ cu_seqlens_k, | ||||||||||
const int *__restrict__ block_tables, | ||||||||||
const int *batch_ids, | ||||||||||
const int *tile_ids_per_batch, | ||||||||||
const int max_blocks_per_seq, | ||||||||||
const int kv_num_heads) { | ||||||||||
// start_kv_idx: 每个block的起始kv_idx | ||||||||||
// batch_id:每个block属于的batch | ||||||||||
// TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8) | ||||||||||
const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; | ||||||||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y; | ||||||||||
|
||||||||||
const uint32_t batch_id = batch_ids[tile_idx]; | ||||||||||
const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; | ||||||||||
const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; | ||||||||||
if (seq_lens_this_time <= 0) { | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; | ||||||||||
uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; | ||||||||||
// cache_kv idx | ||||||||||
uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; | ||||||||||
uint32_t block_stride = kv_num_heads * kv_h_stride; | ||||||||||
const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; | ||||||||||
const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; | ||||||||||
|
||||||||||
// k_out v_out idx | ||||||||||
uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; | ||||||||||
T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针 | ||||||||||
T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针 | ||||||||||
|
||||||||||
uint32_t kv_frag[4]; | ||||||||||
T *frag_dq_T = reinterpret_cast<T *>(kv_frag); | ||||||||||
|
||||||||||
constexpr uint32_t num_vecs_per_head = | ||||||||||
HEAD_DIM / num_elems_per_128b<CacheT>(); | ||||||||||
constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head; | ||||||||||
|
||||||||||
extern __shared__ uint8_t smem[]; | ||||||||||
smem_t k_smem(smem); | ||||||||||
uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( | ||||||||||
wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp | ||||||||||
|
||||||||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( | ||||||||||
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); | ||||||||||
|
||||||||||
uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + | ||||||||||
tid % 8 * num_elems_per_128b<CacheT>(); | ||||||||||
|
||||||||||
// load k_smem 行是64 列是128 | ||||||||||
for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 | ||||||||||
for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 | ||||||||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>( | ||||||||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); | ||||||||||
k_smem_offset_w = | ||||||||||
k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy); | ||||||||||
k_read_idx += 8 * num_elems_per_128b<CacheT>(); | ||||||||||
} | ||||||||||
k_smem_offset_w = | ||||||||||
k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8; | ||||||||||
k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>(); | ||||||||||
} | ||||||||||
commit_group(); | ||||||||||
wait_group<0>(); | ||||||||||
__syncthreads(); | ||||||||||
|
||||||||||
// deal k_smem 行是64 列是128 | ||||||||||
for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 | ||||||||||
uint32_t row_idx = wid * 16 + tid / 4; | ||||||||||
for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) | ||||||||||
uint32_t col_idx = fy * 16 + tid % 4 * 2; | ||||||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); | ||||||||||
// 存储 | ||||||||||
/*** | ||||||||||
r0c0,r0c1, r0c8,r0c9 | ||||||||||
r8c0,r8c1, r8c8,r8c9 | ||||||||||
***/ | ||||||||||
T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; | ||||||||||
T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; | ||||||||||
|
||||||||||
if (row_idx < end_idx) { | ||||||||||
k_tile_ptr0[0] = frag_dq_T[0]; | ||||||||||
k_tile_ptr0[1] = frag_dq_T[1]; | ||||||||||
k_tile_ptr0[8] = frag_dq_T[4]; | ||||||||||
k_tile_ptr0[9] = frag_dq_T[5]; | ||||||||||
} | ||||||||||
|
||||||||||
if (row_idx + 8 < end_idx) { | ||||||||||
k_tile_ptr1[0] = frag_dq_T[2]; | ||||||||||
k_tile_ptr1[1] = frag_dq_T[3]; | ||||||||||
k_tile_ptr0[8] = frag_dq_T[6]; | ||||||||||
k_tile_ptr0[9] = frag_dq_T[7]; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This write is using
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
} | ||||||||||
k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( | ||||||||||
k_smem_offset_r, fy); | ||||||||||
} | ||||||||||
k_smem_offset_r = | ||||||||||
k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16; | ||||||||||
} | ||||||||||
// ================v================ | ||||||||||
smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); | ||||||||||
uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( | ||||||||||
wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp | ||||||||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( | ||||||||||
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); | ||||||||||
|
||||||||||
uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + | ||||||||||
tid % 8 * num_elems_per_128b<CacheT>(); | ||||||||||
|
||||||||||
// load v_smem 行是64 列是128 | ||||||||||
for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 | ||||||||||
for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 | ||||||||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>( | ||||||||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); | ||||||||||
v_smem_offset_w = | ||||||||||
v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy); | ||||||||||
v_read_idx += 8 * num_elems_per_128b<CacheT>(); | ||||||||||
} | ||||||||||
v_smem_offset_w = | ||||||||||
v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8; | ||||||||||
v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>(); | ||||||||||
} | ||||||||||
commit_group(); | ||||||||||
wait_group<0>(); | ||||||||||
__syncthreads(); | ||||||||||
|
||||||||||
// deal v_smem 行是64 列是128 | ||||||||||
|
||||||||||
for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 | ||||||||||
uint32_t row_idx = wid * 16 + tid / 4; | ||||||||||
for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) | ||||||||||
uint32_t col_idx = fy * 16 + tid % 4 * 2; | ||||||||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); | ||||||||||
// 存储 | ||||||||||
/*** | ||||||||||
r0c0,r0c1, r0c8,r0c9 | ||||||||||
r8c0,r8c1, r8c8,r8c9 | ||||||||||
***/ | ||||||||||
T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; | ||||||||||
T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride; | ||||||||||
|
||||||||||
if (row_idx < end_idx) { | ||||||||||
v_tile_ptr0[0] = frag_dq_T[0]; | ||||||||||
v_tile_ptr0[1] = frag_dq_T[1]; | ||||||||||
v_tile_ptr0[8] = frag_dq_T[4]; | ||||||||||
v_tile_ptr0[9] = frag_dq_T[5]; | ||||||||||
} | ||||||||||
|
||||||||||
if (row_idx + 8 < end_idx) { | ||||||||||
v_tile_ptr1[0] = frag_dq_T[2]; | ||||||||||
v_tile_ptr1[1] = frag_dq_T[3]; | ||||||||||
v_tile_ptr0[8] = frag_dq_T[6]; | ||||||||||
v_tile_ptr0[9] = frag_dq_T[7]; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the K path, this write uses
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
} | ||||||||||
v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>( | ||||||||||
v_smem_offset_r, fy); | ||||||||||
} | ||||||||||
v_smem_offset_r = | ||||||||||
v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16; | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
template <typename T, | ||||||||||
typename CacheT, | ||||||||||
uint32_t HEAD_DIM, | ||||||||||
|
@@ -214,7 +388,7 @@ __global__ void append_dequant_cache_kv_c8( | |||||||||
|
||||||||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>( | ||||||||||
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); | ||||||||||
|
||||||||||
uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + | ||||||||||
tid % 8 * num_elems_per_128b<CacheT>(); | ||||||||||
|
||||||||||
|
@@ -328,7 +502,7 @@ __global__ void append_dequant_cache_kv_c8( | |||||||||
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale; | ||||||||||
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale; | ||||||||||
|
||||||||||
|
||||||||||
convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T | ||||||||||
#ifdef C8_DEBUG | ||||||||||
if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { | ||||||||||
|
@@ -371,14 +545,36 @@ void AppendDequantCache( | |||||||||
paddle::Tensor *k_out, | ||||||||||
paddle::Tensor *v_out, | ||||||||||
const cudaStream_t& stream | ||||||||||
) { | ||||||||||
) { | ||||||||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type; | ||||||||||
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { | ||||||||||
constexpr int NUM_WARPS = 4; | ||||||||||
int block_num = cache_num_blocks_x.data<int>()[0]; | ||||||||||
dim3 grids(block_num, 1, kv_num_heads); | ||||||||||
dim3 blocks(32, NUM_WARPS); | ||||||||||
|
||||||||||
constexpr int NUM_WARPS = 4; | ||||||||||
int block_num = cache_num_blocks_x.data<int>()[0]; | ||||||||||
dim3 grids(block_num, 1, kv_num_heads); | ||||||||||
dim3 blocks(32, NUM_WARPS); | ||||||||||
if (cache_quant_type == "none") { | ||||||||||
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; | ||||||||||
auto kernel_func = append_dequant_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>; | ||||||||||
|
||||||||||
if (smem_size >= 48 * 1024) { | ||||||||||
cudaFuncSetAttribute(kernel_func, | ||||||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, | ||||||||||
smem_size); | ||||||||||
} | ||||||||||
kernel_func<<<grids, blocks, smem_size, stream>>>( | ||||||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())), | ||||||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())), | ||||||||||
reinterpret_cast<NV_TYPE *>(k_out->data<T>()), | ||||||||||
reinterpret_cast<NV_TYPE *>(v_out->data<T>()), | ||||||||||
seq_lens_this_time.data<int>(), | ||||||||||
seq_lens_decoder.data<int>(), | ||||||||||
cu_seqlens_k.data<int>(), | ||||||||||
block_tables.data<int>(), | ||||||||||
cache_batch_ids.data<int>(), | ||||||||||
cache_tile_ids_per_batch.data<int>(), | ||||||||||
max_blocks_per_seq, | ||||||||||
kv_num_heads | ||||||||||
); | ||||||||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { | ||||||||||
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; | ||||||||||
|
||||||||||
auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>; | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition is comparing a pointer
seq_lens_this_time
instead of its value. You likely meant to check the sequence length element, e.g.,seq_lens_this_time[tile_idx] <= 0
orseq_lens_this_time[batch_id] <= 0
.Copilot uses AI. Check for mistakes.