From 7bc246b6f79f365bb03d5a6c974709979635811d Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Wed, 9 Jul 2025 11:40:31 +0800 Subject: [PATCH 1/3] support c16 prompt_cache in fa3 --- .../append_attn/gqa_rope_write_cache.cu | 216 +++++++++++++++++- 1 file changed, 206 insertions(+), 10 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 2f3b339009..1cd8581a79 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -16,7 +16,6 @@ #include "paddle/extension.h" #include "paddle/phi/core/memory/memcpy.h" #include "encoder_write_cache_with_rope_impl.cuh" -#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h" #include "paddle/phi/backends/context_pool.h" #include "remote_cache_kv_ipc.h" @@ -148,6 +147,181 @@ void gqa_rotary_qk_split_variable( dim_head); } +template +__global__ void append_dequant_cache_kv_c16( + const T *__restrict__ cache_k, + const T *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { + // start_kv_idx: 每个block的起始kv_idx + // batch_id:每个block属于的batch + // TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8) + const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + const uint32_t batch_id = batch_ids[tile_idx]; + const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; + const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; + if (seq_lens_this_time <= 0) { + return; + } + + const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; + uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; + // cache_kv idx + uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + uint32_t block_stride = kv_num_heads * kv_h_stride; + const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + + // k_out v_out idx + uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; + T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针 + T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针 + + uint32_t kv_frag[4]; + T *frag_dq_T = reinterpret_cast(kv_frag); + + constexpr uint32_t num_vecs_per_head = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head; + + extern __shared__ uint8_t smem[]; + smem_t k_smem(smem); + uint32_t k_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + + // load k_smem 行是64 列是128 + for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 + for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 + k_smem.load_128b_async( + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = + k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, 
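        // Hedged sizing sketch for this staging step, assuming BLOCK_SIZE = 64,
        // HEAD_DIM = 128 and a 16-bit T (bf16/fp16), as the surrounding comments imply:
        //   num_elems_per_128b<T>()   = 16 B / 2 B = 8 elements per cp.async vector
        //   vectors per row           = 128 / 8    = 16 x 128-bit loads per row
        //   one K (or V) tile in smem = 64 * 128 * 2 B = 16 KB
        //   K + V tiles together      = 32 KB, matching the
        //   BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2 dynamic smem requested at the launch site.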
fy); + k_read_idx += 8 * num_elems_per_128b(); + } + k_smem_offset_w = + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8; + k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal k_smem 行是64 列是128 + for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) + uint32_t col_idx = fy * 16 + tid % 4 * 2; + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // 存储 + /*** + r0c0,r0c1, r0c8,r0c9 + r8c0,r8c1, r8c8,r8c9 + ***/ + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; + + if (row_idx < end_idx) { + k_tile_ptr0[0] = frag_dq_T[0]; + k_tile_ptr0[1] = frag_dq_T[1]; + k_tile_ptr0[8] = frag_dq_T[4]; + k_tile_ptr0[9] = frag_dq_T[5]; + } + + if (row_idx + 8 < end_idx) { + k_tile_ptr1[0] = frag_dq_T[2]; + k_tile_ptr1[1] = frag_dq_T[3]; + k_tile_ptr0[8] = frag_dq_T[6]; + k_tile_ptr0[9] = frag_dq_T[7]; + } + k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( + k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16; + } + // ================v================ + smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); + uint32_t v_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + + // load v_smem 行是64 列是128 + for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 + for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 + v_smem.load_128b_async( + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w = + v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy); + v_read_idx += 8 * num_elems_per_128b(); + } + v_smem_offset_w = + v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8; + v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal v_smem 行是64 列是128 + + for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) + uint32_t col_idx = fy * 16 + tid % 4 * 2; + v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); + // 存储 + /*** + r0c0,r0c1, r0c8,r0c9 + r8c0,r8c1, r8c8,r8c9 + ***/ + T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride; + + if (row_idx < end_idx) { + v_tile_ptr0[0] = frag_dq_T[0]; + v_tile_ptr0[1] = frag_dq_T[1]; + v_tile_ptr0[8] = frag_dq_T[4]; + v_tile_ptr0[9] = frag_dq_T[5]; + } + + if (row_idx + 8 < end_idx) { + v_tile_ptr1[0] = frag_dq_T[2]; + v_tile_ptr1[1] = frag_dq_T[3]; + v_tile_ptr0[8] = frag_dq_T[6]; + v_tile_ptr0[9] = frag_dq_T[7]; + } + v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>( + v_smem_offset_r, fy); + } + v_smem_offset_r = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16; + } +} + template ( wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - + uint32_t k_read_idx = (wid 
* 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); @@ -328,7 +502,7 @@ __global__ void append_dequant_cache_kv_c8( v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale; v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale; - + convert_c8(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T #ifdef C8_DEBUG if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { @@ -371,14 +545,36 @@ void AppendDequantCache( paddle::Tensor *k_out, paddle::Tensor *v_out, const cudaStream_t& stream -) { +) { using NV_TYPE = typename cascade_attn_type_traits::type; - if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { - constexpr int NUM_WARPS = 4; - int block_num = cache_num_blocks_x.data()[0]; - dim3 grids(block_num, 1, kv_num_heads); - dim3 blocks(32, NUM_WARPS); - + constexpr int NUM_WARPS = 4; + int block_num = cache_num_blocks_x.data()[0]; + dim3 grids(block_num, 1, kv_num_heads); + dim3 blocks(32, NUM_WARPS); + if (cache_quant_type == "none") { + const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; + auto kernel_func = append_dequant_cache_kv_c16; + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel_func, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + kernel_func<<>>( + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads + ); + } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; auto kernel_func = append_dequant_cache_kv_c8; From fb577407c8b2f2c6cfce51fd5ca58cdfa09a6b9f Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Wed, 9 Jul 2025 13:00:31 +0800 Subject: [PATCH 2/3] fix prefix_cache in fa3 --- .../append_attn/gqa_rope_write_cache.cu | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 1cd8581a79..b6f5b8f850 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -152,7 +152,7 @@ template -__global__ void append_dequant_cache_kv_c16( +__global__ void append_cache_kv_c16( const T *__restrict__ cache_k, const T *__restrict__ cache_v, T *__restrict__ k_out, @@ -174,7 +174,7 @@ __global__ void append_dequant_cache_kv_c16( const uint32_t batch_id = batch_ids[tile_idx]; const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; - if (seq_lens_this_time <= 0) { + if (seq_lens_this_time[batch_id] <= 0) { return; } @@ -250,8 +250,8 @@ __global__ void append_dequant_cache_kv_c16( if (row_idx + 8 < end_idx) { k_tile_ptr1[0] = frag_dq_T[2]; k_tile_ptr1[1] = frag_dq_T[3]; - k_tile_ptr0[8] = frag_dq_T[6]; - k_tile_ptr0[9] = frag_dq_T[7]; + k_tile_ptr1[8] = frag_dq_T[6]; + k_tile_ptr1[9] = frag_dq_T[7]; } k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( k_smem_offset_r, fy); @@ -311,8 +311,8 @@ __global__ void append_dequant_cache_kv_c16( if (row_idx + 8 < end_idx) { v_tile_ptr1[0] = frag_dq_T[2]; v_tile_ptr1[1] = frag_dq_T[3]; - v_tile_ptr0[8] = frag_dq_T[6]; - 
v_tile_ptr0[9] = frag_dq_T[7]; + v_tile_ptr1[8] = frag_dq_T[6]; + v_tile_ptr1[9] = frag_dq_T[7]; } v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>( v_smem_offset_r, fy); @@ -328,7 +328,7 @@ template -__global__ void append_dequant_cache_kv_c8( +__global__ void append_cache_kv_c8( const CacheT *__restrict__ cache_k, const CacheT *__restrict__ cache_v, T *__restrict__ k_out, @@ -527,7 +527,7 @@ __global__ void append_dequant_cache_kv_c8( } template -void AppendDequantCache( +void AppendCacheKV( const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, const paddle::Tensor &cache_k_dequant_scales, @@ -553,7 +553,7 @@ void AppendDequantCache( dim3 blocks(32, NUM_WARPS); if (cache_quant_type == "none") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; - auto kernel_func = append_dequant_cache_kv_c16; + auto kernel_func = append_cache_kv_c16; if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(kernel_func, @@ -577,9 +577,9 @@ void AppendDequantCache( } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; - auto kernel_func = append_dequant_cache_kv_c8; + auto kernel_func = append_cache_kv_c8; if (cache_quant_type == "cache_fp8") { - kernel_func = append_dequant_cache_kv_c8; + kernel_func = append_cache_kv_c8; } if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(kernel_func, @@ -757,7 +757,7 @@ std::vector GQARopeWriteCacheKernel( } if (token_num < kv_token_num) { - AppendDequantCache( + AppendCacheKV( key_cache, value_cache, cache_k_dequant_scales.get(), From 0daa67b3eafa6a4cd3598d83d3775dd116faa466 Mon Sep 17 00:00:00 2001 From: lizhenyun01 <1500424927@qq.com> Date: Wed, 16 Jul 2025 13:06:50 +0800 Subject: [PATCH 3/3] support c4 prompt_cache in fa3 --- .../append_attn/gqa_rope_write_cache.cu | 410 +++++++++++++++--- 1 file changed, 338 insertions(+), 72 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index b6f5b8f850..a2d4da2862 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -165,9 +165,9 @@ __global__ void append_cache_kv_c16( const int *tile_ids_per_batch, const int max_blocks_per_seq, const int kv_num_heads) { - // start_kv_idx: 每个block的起始kv_idx - // batch_id:每个block属于的batch - // TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8) + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; @@ -188,8 +188,8 @@ __global__ void append_cache_kv_c16( // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针 - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针 + T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; uint32_t kv_frag[4]; T *frag_dq_T = reinterpret_cast(kv_frag); @@ -209,9 +209,9 @@ __global__ void append_cache_kv_c16( uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); - // load k_smem 行是64 列是128 - for (int fz = 0; fz < 4; 
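  // Roughly what this c16 ("none") path computes per element -- a hedged scalar
  // sketch of the tiled cp.async/ldmatrix code below, not an exact restatement:
  //   k_out[(cu_seqlens_k[b] + kv_idx) * kv_num_heads * HEAD_DIM
  //         + kv_head * HEAD_DIM + d]
  //     = cache_k[block_tables[b][kv_idx / BLOCK_SIZE] * kv_num_heads * BLOCK_SIZE * HEAD_DIM
  //               + kv_head * BLOCK_SIZE * HEAD_DIM
  //               + (kv_idx % BLOCK_SIZE) * HEAD_DIM + d];
  // i.e. a straight copy of the paged bf16 cache back into contiguous K/V buffers,
  // with no dequantization; the same mapping applies to v_out / cache_v.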
fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 - for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 + // load k_smem 64 rows 128 cols + for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter k_smem.load_128b_async( k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); k_smem_offset_w = @@ -219,20 +219,20 @@ __global__ void append_cache_kv_c16( k_read_idx += 8 * num_elems_per_128b(); } k_smem_offset_w = - k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8; - k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 16; + k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); } commit_group(); wait_group<0>(); __syncthreads(); - // deal k_smem 行是64 列是128 - for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + // deal k_smem 64 rows 128 cols + for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter uint32_t col_idx = fy * 16 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); - // 存储 + // layout /*** r0c0,r0c1, r0c8,r0c9 r8c0,r8c1, r8c8,r8c9 @@ -243,13 +243,13 @@ __global__ void append_cache_kv_c16( if (row_idx < end_idx) { k_tile_ptr0[0] = frag_dq_T[0]; k_tile_ptr0[1] = frag_dq_T[1]; - k_tile_ptr0[8] = frag_dq_T[4]; - k_tile_ptr0[9] = frag_dq_T[5]; + k_tile_ptr0[8] = frag_dq_T[2]; + k_tile_ptr0[9] = frag_dq_T[3]; } if (row_idx + 8 < end_idx) { - k_tile_ptr1[0] = frag_dq_T[2]; - k_tile_ptr1[1] = frag_dq_T[3]; + k_tile_ptr1[0] = frag_dq_T[4]; + k_tile_ptr1[1] = frag_dq_T[5]; k_tile_ptr1[8] = frag_dq_T[6]; k_tile_ptr1[9] = frag_dq_T[7]; } @@ -259,6 +259,7 @@ __global__ void append_cache_kv_c16( k_smem_offset_r = k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16; } + // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( @@ -269,9 +270,9 @@ __global__ void append_cache_kv_c16( uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); - // load v_smem 行是64 列是128 - for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 - for (int fy = 0; fy < 2; fy++) { // 一次8个128b = 64个bf16 + // load v_smem 64 rows 128 cols + for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter v_smem.load_128b_async( v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); v_smem_offset_w = @@ -279,21 +280,20 @@ __global__ void append_cache_kv_c16( v_read_idx += 8 * num_elems_per_128b(); } v_smem_offset_w = - v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8; - v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); + v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 16; + v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); } commit_group(); wait_group<0>(); __syncthreads(); - // deal v_smem 行是64 列是128 - - for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + // deal v_smem 64 rows 128 cols + for (int fz = 0; fz < 1; fz++) { // 16 rows 
pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 8; fy++) { // 1次2个128b(16个bf16),8次循环16个128b(128个bf16) + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter uint32_t col_idx = fy * 16 + tid % 4 * 2; v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); - // 存储 + // layout /*** r0c0,r0c1, r0c8,r0c9 r8c0,r8c1, r8c8,r8c9 @@ -304,13 +304,13 @@ __global__ void append_cache_kv_c16( if (row_idx < end_idx) { v_tile_ptr0[0] = frag_dq_T[0]; v_tile_ptr0[1] = frag_dq_T[1]; - v_tile_ptr0[8] = frag_dq_T[4]; - v_tile_ptr0[9] = frag_dq_T[5]; + v_tile_ptr0[8] = frag_dq_T[2]; + v_tile_ptr0[9] = frag_dq_T[3]; } if (row_idx + 8 < end_idx) { - v_tile_ptr1[0] = frag_dq_T[2]; - v_tile_ptr1[1] = frag_dq_T[3]; + v_tile_ptr1[0] = frag_dq_T[4]; + v_tile_ptr1[1] = frag_dq_T[5]; v_tile_ptr1[8] = frag_dq_T[6]; v_tile_ptr1[9] = frag_dq_T[7]; } @@ -343,9 +343,9 @@ __global__ void append_cache_kv_c8( const int *tile_ids_per_batch, const int max_blocks_per_seq, const int kv_num_heads) { - // start_kv_idx: 每个block的起始kv_idx - // batch_id:每个block属于的batch - // TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8) + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; @@ -366,8 +366,8 @@ __global__ void append_cache_kv_c8( // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针 - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针 + T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; uint32_t k_frag[4], v_frag[4], frag_dq[4]; T *frag_dq_T = reinterpret_cast(frag_dq); @@ -392,9 +392,9 @@ __global__ void append_cache_kv_c8( uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); - // load k_smem 行是64 列是128 - for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 - for (int fy = 0; fy < 1; fy++) { // 一次8个128b = 128个uint8 + // load v_smem 64 rows, 128 cols + for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter k_smem.load_128b_async( k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); k_smem_offset_w = @@ -409,13 +409,13 @@ __global__ void append_cache_kv_c8( wait_group<0>(); __syncthreads(); - // deal k_smem 行是64 列是128 - for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + // deal k_smem 64 rows, 128 cols + for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 4; fy++) { // 1次2个128b(32个uint8),4次循环8个128b(128个uint8) + for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter uint32_t col_idx = fy * 32 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); - // 反量化 存储 + // layout /*** r0c0,r0c1,r0c8,r0c9, r8c0,r8c1,r8c8,r8c9 r0c16,r0c17,r0c24,r0c25, r8c16,r8c17,r8c24,r8c25 @@ -425,7 +425,7 @@ __global__ void append_cache_kv_c8( T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; if (row_idx < end_idx) { - convert_c8(frag_dq_T,k_frag[2 
* i]); // 4个uint8/fp8 -> 4个T + convert_c8(frag_dq_T,k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale; k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale; @@ -434,7 +434,7 @@ __global__ void append_cache_kv_c8( } if (row_idx + 8 < end_idx) { - convert_c8(frag_dq_T + 4,k_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T + convert_c8(frag_dq_T + 4,k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale; k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale; @@ -449,8 +449,8 @@ __global__ void append_cache_kv_c8( k_smem_offset_r = k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 8; } - // ================v================ + // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( wid * 8 + tid / 4, tid % 4); // 4 * 8 per warp @@ -460,9 +460,9 @@ __global__ void append_cache_kv_c8( uint32_t v_read_idx = (wid * 8 + tid / 4) * BLOCK_SIZE + tid % 4 * num_elems_per_128b(); - // load v_smem 行是128 列是64 - for (int fy = 0; fy < 4; fy++) { // 每个warp1次8行,循环4次32行,4个warp128行 - for (int fz = 0; fz < 1; fz++) { // 一次4个128b = 64个uint8 + // load v_smem 128 rows 64 cols + for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter + for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter v_smem.load_128b_async( v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); v_smem_offset_w = @@ -478,39 +478,24 @@ __global__ void append_cache_kv_c8( wait_group<0>(); __syncthreads(); - // deal v_smem 行是128 列是64 row_idx是head_dim, col_idx是block_size - for (int fy = 0; fy < 2; fy++) { // 每个warp1次16行,循环2次32行,4个warp128行 + // deal v_smem 128 rows 64 cols + for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; - for (int fz = 0; fz < 2; fz++) { // 1次2个128b(32个uint8),2次循环4个128b(64个uint8) + for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter uint32_t kv_idx = fz * 32 + tid % 4 * 2; v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); - // 反量化 存储 + // layout for (int i = 0; i < 4 / 2; i++) { T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx; T *v_tile_ptr1 = v_tile_ptr0 + 8; if (kv_idx < end_idx) { - convert_c8(frag_dq_T, v_frag[2 * i]); // 4个uint8/fp8 -> 4个T -#ifdef C8_DEBUG - if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { - printf("1.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n", - fy, fz, kv_idx, dim_idx, static_cast(frag_dq_T[0]), static_cast(frag_dq_T[1]), - static_cast(frag_dq_T[2]), static_cast(frag_dq_T[3])); - } -#endif + convert_c8(frag_dq_T, v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale; v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale; v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale; v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale; - - convert_c8(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T -#ifdef C8_DEBUG - if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { - printf("2.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n", - fy, fz, kv_idx, dim_idx + 8, static_cast(frag_dq_T[4]), static_cast(frag_dq_T[5]), - static_cast(frag_dq_T[6]), static_cast(frag_dq_T[7])); - } -#endif + convert_c8(frag_dq_T + 4, v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T v_tile_ptr1[0] = 
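            // Hedged sketch of the c8 dequant applied here: convert_c8<T, IsFP8>()
            // unpacks one 32-bit register holding four int8/fp8 codes into four T
            // values, and each store rescales with the scalar cache_v_scale, roughly
            //   v = converted_code * cache_v_scale;
            // Because the c8 V cache keeps a [HEAD_DIM, BLOCK_SIZE] layout per
            // (block, head), the four values of one fragment belong to the same
            // head_dim channel of four different tokens, hence the stores stride by
            // kv_t_stride (kv_idx, +1, +8, +9) instead of being contiguous.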
frag_dq_T[4] * cache_v_scale; v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale; v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale; @@ -526,12 +511,237 @@ __global__ void append_cache_kv_c8( } } +template +__global__ void append_cache_kv_c4( + const CacheT *__restrict__ cache_k, + const CacheT *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const T *__restrict__ cache_k_dequant_scales, + const T *__restrict__ cache_v_dequant_scales, + const T *__restrict__ cache_k_zero_point, + const T *__restrict__ cache_v_zero_point, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) + const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + const uint32_t batch_id = batch_ids[tile_idx]; + const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; + const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; + if (seq_lens_this_time <= 0) { + return; + } + + const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; + uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; + if (block_id < 0) block_id = 0; + + constexpr uint32_t HEAD_DIM_HALF = HEAD_DIM / 2; + constexpr uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2; + // cache_kv idx + uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM_HALF; + uint32_t block_stride = kv_num_heads * kv_h_stride; + const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + + // k_out v_out idx + uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; + T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + + extern __shared__ uint8_t smem[]; + + uint32_t k_frag[4], v_frag[4], frag_dq[8]; + T *frag_dq_T = reinterpret_cast(frag_dq); + + // load dequant scales and zero points + const T *cache_k_scale_now = cache_k_dequant_scales + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_dequant_scales + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast(136.f); + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast(136.f); + } + + smem_t k_smem(smem); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM_HALF / num_elems_per_128b(); // 2 + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE_HALF / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / 
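  // Hedged sketch of the cache_int4_zp dequant used below: every cache byte packs
  // two 4-bit codes, convert_int4() expands one 32-bit register (eight codes) into
  // eight T values, and each store reconstructs, per head_dim channel d,
  //   x = convert_int4(code) * cache_k_scale[d] + (cache_k_zero_point[d] - 136)
  // with the constant 136 already folded into the zero-points when they were staged
  // into shared memory a few lines above.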
num_vecs_per_head_k; // 4 + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + uint32_t k_smem_offset_w = smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 2(iter) * 4(warp) * 8 row per warp + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); // + + uint32_t k_read_idx = (wid * 8 + tid / 4) * HEAD_DIM / 2 + + tid % 4 * num_elems_per_128b(); + + // load k_smem 64 rows 128 cols + for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter + k_smem.load_128b_async( + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = + k_smem.advance_offset_by_column<4, num_vecs_per_head_k>(k_smem_offset_w, fy); + k_read_idx += 4 * num_elems_per_128b(); + } + k_smem_offset_w = + k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 4; + k_read_idx += 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal k_smem 64 rows 128 cols + for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter + uint32_t col_idx = fy * 64 + tid % 4 * 2; + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); + + + for (int i = 0; i < 2; i++) { + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; + convert_int4(frag_dq_T, k_frag[2 * i]); + convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]); + + if (row_idx < end_idx) { + k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx]; + k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1]; + k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8]; + k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9]; + k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16]; + k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17]; + k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24]; + k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25]; + } + + if (row_idx + 8 < end_idx) { + k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx]; + k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1]; + k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8]; + k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9]; + k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16]; + k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17]; + k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24]; + k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25]; + } + col_idx += 32; + } + 
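      // Worked example of the store above (illustrative numbers only): if
      // convert_int4() yields 9 for some code, cache_k_scale_smem[c] = 0.05, and the
      // original zero-point was 137 (staged as 137 - 136 = 1), the recovered value
      // is 9 * 0.05 + 1 = 1.45, written out in T precision.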
k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>( + k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 4; + } + + // ================v================ + smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) / 2); + uint32_t v_smem_offset_w = smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 4 * 8 per warp + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_read_idx = (wid * 16 + tid / 2) * BLOCK_SIZE_HALF + + tid % 2 * num_elems_per_128b(); + // load v_smem 128 rows 64 rows + for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter + v_smem.load_128b_async( + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(v_smem_offset_w, fz); + v_read_idx += 2 * num_elems_per_128b(); + } + v_smem_offset_w = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 2; + v_read_idx += 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b(); + } + + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal v_smem 128 rows 64 cols + for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter + uint32_t kv_idx = fz * 64 + tid % 4 * 2; + v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); + // layout + for (int i = 0; i < 2; i++) { + T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx; + T *v_tile_ptr1 = v_tile_ptr0 + 8; + if (kv_idx < end_idx) { + convert_int4(frag_dq_T, v_frag[2 * i]); + convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]); + + v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx]; + + v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 
8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8]; + } + kv_idx += 32; + } + v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_r, fz); + } + v_smem_offset_r = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 2; + } +} + template void AppendCacheKV( const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, const paddle::Tensor &cache_k_dequant_scales, const paddle::Tensor &cache_v_dequant_scales, + const paddle::Tensor &cache_k_zp, + const paddle::Tensor &cache_v_zp, const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &cu_seqlens_k, @@ -602,6 +812,34 @@ void AppendCacheKV( max_blocks_per_seq, kv_num_heads ); + } else if (cache_quant_type == "cache_int4_zp") { + const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; + + auto kernel_func = append_cache_kv_c4; + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel_func, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + kernel_func<<>>( + cache_k.data(), + cache_v.data(), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + reinterpret_cast(const_cast(cache_k_dequant_scales.data())), + reinterpret_cast(const_cast(cache_v_dequant_scales.data())), + reinterpret_cast(const_cast(cache_k_zp.data())), + reinterpret_cast(const_cast(cache_v_zp.data())), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads + ); } else { PADDLE_THROW("%s mode isn't implemented yet", cache_quant_type.c_str()); } @@ -648,7 +886,7 @@ std::vector GQARopeWriteCacheKernel( const int block_size = key_cache.dims()[2]; const int batch_size = cum_offsets.dims()[0]; const int kv_num_heads = key_cache_dims[1]; - const int head_dim = key_cache_dims[3]; + const int head_dim = cache_quant_type == "cache_int4_zp" ? 
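  // For cache_int4_zp the cache tensor stores two 4-bit values per byte, so its last
  // dimension holds head_dim / 2 bytes (e.g. a 128-dim head occupies 64 bytes per
  // token per head); the logical head_dim is therefore recovered by doubling it here.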
key_cache_dims[3] * 2 : key_cache_dims[3]; const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads; const float softmax_scale = 1.f / sqrt(head_dim); @@ -735,6 +973,32 @@ std::vector GQARopeWriteCacheKernel( stream, const_cast(&key_cache), const_cast(&value_cache)); + } else if (cache_quant_type == "cache_int4_zp") { + CascadeAppendWriteCacheKVC4QKV( + meta_data, + *const_cast(&key_cache), + *const_cast(&value_cache), + qkv_out, + cache_k_quant_scales.get(), + cache_v_quant_scales.get(), + cache_k_zp.get(), + cache_v_zp.get(), + seq_lens_this_time, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + kv_batch_ids, + kv_tile_ids, + kv_num_blocks_data, + max_seq_len, + stream, + const_cast(&key_cache), + const_cast(&value_cache)); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, cache_fp8, " + "cache_int4_zp]"); } const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal"); const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); @@ -762,6 +1026,8 @@ std::vector GQARopeWriteCacheKernel( value_cache, cache_k_dequant_scales.get(), cache_v_dequant_scales.get(), + cache_k_zp.get(), + cache_v_zp.get(), seq_lens_this_time, seq_lens_decoder, cu_seqlens_k,