Skip to content

Commit 7bc246b

Browse files
committed
support c16 prompt_cache in fa3
1 parent 0d03403 commit 7bc246b

File tree

1 file changed

+206
-10
lines changed

1 file changed

+206
-10
lines changed

custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu

Lines changed: 206 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "paddle/extension.h"
1717
#include "paddle/phi/core/memory/memcpy.h"
1818
#include "encoder_write_cache_with_rope_impl.cuh"
19-
#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h"
2019
#include "paddle/phi/backends/context_pool.h"
2120
#include "remote_cache_kv_ipc.h"
2221

@@ -148,6 +147,181 @@ void gqa_rotary_qk_split_variable(
148147
dim_head);
149148
}
150149

150+
// Copies ("dequantizes" in the c16 / no-quant case) one BLOCK_SIZE x HEAD_DIM tile of
// the paged K/V cache into the contiguous [token, kv_head, dim] layout expected by the
// FA3 prompt-cache path.
//
// Grid:  (num_tiles, 1, kv_num_heads); Block: (32, NUM_WARPS).
// Dynamic shared memory: 2 * BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) bytes (K region
// followed by V region).
//
// batch_ids[tile_idx]            : batch this tile belongs to.
// tile_ids_per_batch[tile_idx]   : tile index within that batch; tile * BLOCK_SIZE is
//                                  the first kv position handled here.
// TODO (from original author): 1. prefetch scales  2. reuse frag_dq_T
// 3. software pipelining  4. coalesce stores  5. CacheT support (int8/fp8)
template <typename T,
          typename CacheT,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t NUM_WARPS = 4>
__global__ void append_dequant_cache_kv_c16(
    const T *__restrict__ cache_k,
    const T *__restrict__ cache_v,
    T *__restrict__ k_out,
    T *__restrict__ v_out,
    const int *__restrict__ seq_lens_this_time,
    const int *__restrict__ seq_lens_decoder,
    const int *__restrict__ cu_seqlens_k,
    const int *__restrict__ block_tables,
    const int *batch_ids,
    const int *tile_ids_per_batch,
    const int max_blocks_per_seq,
    const int kv_num_heads) {
  const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z;
  const uint32_t tid = threadIdx.x, wid = threadIdx.y;

  const uint32_t batch_id = batch_ids[tile_idx];
  const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
  // Number of valid rows remaining in this tile (rows >= end_idx are padding).
  const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
  // BUG FIX: the original tested `seq_lens_this_time <= 0`, comparing the POINTER
  // itself against 0 — always false, so inactive batches were never skipped.
  // The per-batch element must be read.
  if (seq_lens_this_time[batch_id] <= 0) {
    return;
  }

  const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq;
  uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE];

  // Source (paged cache) indexing: [block, kv_head, BLOCK_SIZE, HEAD_DIM].
  uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
  uint32_t block_stride = kv_num_heads * kv_h_stride;
  const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride;
  const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride;

  // Destination (contiguous) indexing: token stride covers all kv heads.
  uint32_t kv_t_stride = kv_num_heads * HEAD_DIM;
  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;  // base of current K tile
  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride;  // base of current V tile

  uint32_t kv_frag[4];
  T *frag_dq_T = reinterpret_cast<T *>(kv_frag);

  constexpr uint32_t num_vecs_per_head =
      HEAD_DIM / num_elems_per_128b<CacheT>();
  constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head;

  extern __shared__ uint8_t smem[];
  smem_t k_smem(smem);
  uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 4 + tid / 8, tid % 8);  // 4 rows x 4 iterations per warp

  uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

  // Stage K into shared memory: 64 rows x 128 columns.
  for (int fz = 0; fz < 4; fz++) {    // each warp: 4 rows/iter, 4 iters = 16 rows; 4 warps = 64 rows
    for (int fy = 0; fy < 2; fy++) {  // 8 x 128b = 64 bf16 per iteration
      k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
      k_smem_offset_w =
          k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy);
      k_read_idx += 8 * num_elems_per_128b<CacheT>();
    }
    k_smem_offset_w =
        k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8;
    k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>();
  }
  commit_group();
  wait_group<0>();
  __syncthreads();

  // Drain K from shared memory to k_out: 64 rows x 128 columns.
  for (int fz = 0; fz < 1; fz++) {    // each warp handles 16 rows; 4 warps = 64 rows
    uint32_t row_idx = wid * 16 + tid / 4;
    for (int fy = 0; fy < 8; fy++) {  // 2 x 128b (16 bf16) per iter; 8 iters = 128 bf16
      uint32_t col_idx = fy * 16 + tid % 4 * 2;
      k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
      // Per-thread fragment layout:
      //   r0c0,r0c1, r0c8,r0c9
      //   r8c0,r8c1, r8c8,r8c9
      T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
      T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride;

      if (row_idx < end_idx) {
        k_tile_ptr0[0] = frag_dq_T[0];
        k_tile_ptr0[1] = frag_dq_T[1];
        k_tile_ptr0[8] = frag_dq_T[4];
        k_tile_ptr0[9] = frag_dq_T[5];
      }

      if (row_idx + 8 < end_idx) {
        k_tile_ptr1[0] = frag_dq_T[2];
        k_tile_ptr1[1] = frag_dq_T[3];
        // BUG FIX: the r8 upper columns (r8c8,r8c9) must go through k_tile_ptr1;
        // the original wrote k_tile_ptr0[8]/[9], clobbering row r0 and dropping r8.
        k_tile_ptr1[8] = frag_dq_T[6];
        k_tile_ptr1[9] = frag_dq_T[7];
      }
      k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
          k_smem_offset_r, fy);
    }
    k_smem_offset_r =
        k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16;
  }
  // ================ V ================
  smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT));
  uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 4 + tid / 8, tid % 8);  // 4 rows x 4 iterations per warp
  uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

  // Stage V into shared memory: 64 rows x 128 columns.
  for (int fz = 0; fz < 4; fz++) {    // each warp: 4 rows/iter, 4 iters = 16 rows; 4 warps = 64 rows
    for (int fy = 0; fy < 2; fy++) {  // 8 x 128b = 64 bf16 per iteration
      v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
          v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
      v_smem_offset_w =
          v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy);
      v_read_idx += 8 * num_elems_per_128b<CacheT>();
    }
    v_smem_offset_w =
        v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8;
    v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>();
  }
  commit_group();
  wait_group<0>();
  __syncthreads();

  // Drain V from shared memory to v_out: 64 rows x 128 columns.
  for (int fz = 0; fz < 1; fz++) {    // each warp handles 16 rows; 4 warps = 64 rows
    uint32_t row_idx = wid * 16 + tid / 4;
    for (int fy = 0; fy < 8; fy++) {  // 2 x 128b (16 bf16) per iter; 8 iters = 128 bf16
      uint32_t col_idx = fy * 16 + tid % 4 * 2;
      v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
      // Per-thread fragment layout:
      //   r0c0,r0c1, r0c8,r0c9
      //   r8c0,r8c1, r8c8,r8c9
      T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx;
      T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride;

      if (row_idx < end_idx) {
        v_tile_ptr0[0] = frag_dq_T[0];
        v_tile_ptr0[1] = frag_dq_T[1];
        v_tile_ptr0[8] = frag_dq_T[4];
        v_tile_ptr0[9] = frag_dq_T[5];
      }

      if (row_idx + 8 < end_idx) {
        v_tile_ptr1[0] = frag_dq_T[2];
        v_tile_ptr1[1] = frag_dq_T[3];
        // BUG FIX: same copy-paste defect as the K path — the r8 upper columns
        // must be written through v_tile_ptr1, not v_tile_ptr0.
        v_tile_ptr1[8] = frag_dq_T[6];
        v_tile_ptr1[9] = frag_dq_T[7];
      }
      v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
          v_smem_offset_r, fy);
    }
    v_smem_offset_r =
        v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16;
  }
}
324+
151325
template <typename T,
152326
typename CacheT,
153327
uint32_t HEAD_DIM,
@@ -214,7 +388,7 @@ __global__ void append_dequant_cache_kv_c8(
214388

215389
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
216390
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
217-
391+
218392
uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
219393
tid % 8 * num_elems_per_128b<CacheT>();
220394

@@ -328,7 +502,7 @@ __global__ void append_dequant_cache_kv_c8(
328502
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
329503
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
330504

331-
505+
332506
convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T
333507
#ifdef C8_DEBUG
334508
if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
@@ -371,14 +545,36 @@ void AppendDequantCache(
371545
paddle::Tensor *k_out,
372546
paddle::Tensor *v_out,
373547
const cudaStream_t& stream
374-
) {
548+
) {
375549
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
376-
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
377-
constexpr int NUM_WARPS = 4;
378-
int block_num = cache_num_blocks_x.data<int>()[0];
379-
dim3 grids(block_num, 1, kv_num_heads);
380-
dim3 blocks(32, NUM_WARPS);
381-
550+
constexpr int NUM_WARPS = 4;
551+
int block_num = cache_num_blocks_x.data<int>()[0];
552+
dim3 grids(block_num, 1, kv_num_heads);
553+
dim3 blocks(32, NUM_WARPS);
554+
if (cache_quant_type == "none") {
555+
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
556+
auto kernel_func = append_dequant_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;
557+
558+
if (smem_size >= 48 * 1024) {
559+
cudaFuncSetAttribute(kernel_func,
560+
cudaFuncAttributeMaxDynamicSharedMemorySize,
561+
smem_size);
562+
}
563+
kernel_func<<<grids, blocks, smem_size, stream>>>(
564+
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
565+
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
566+
reinterpret_cast<NV_TYPE *>(k_out->data<T>()),
567+
reinterpret_cast<NV_TYPE *>(v_out->data<T>()),
568+
seq_lens_this_time.data<int>(),
569+
seq_lens_decoder.data<int>(),
570+
cu_seqlens_k.data<int>(),
571+
block_tables.data<int>(),
572+
cache_batch_ids.data<int>(),
573+
cache_tile_ids_per_batch.data<int>(),
574+
max_blocks_per_seq,
575+
kv_num_heads
576+
);
577+
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
382578
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
383579

384580
auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;

0 commit comments

Comments
 (0)