|
16 | 16 | #include "paddle/extension.h"
|
17 | 17 | #include "paddle/phi/core/memory/memcpy.h"
|
18 | 18 | #include "encoder_write_cache_with_rope_impl.cuh"
|
19 |
| -#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h" |
20 | 19 | #include "paddle/phi/backends/context_pool.h"
|
21 | 20 | #include "remote_cache_kv_ipc.h"
|
22 | 21 |
|
@@ -148,6 +147,181 @@ void gqa_rotary_qk_split_variable(
|
148 | 147 | dim_head);
|
149 | 148 | }
|
150 | 149 |
|
| 150 | +template <typename T, |
| 151 | + typename CacheT, |
| 152 | + uint32_t HEAD_DIM, |
| 153 | + uint32_t BLOCK_SIZE, |
| 154 | + uint32_t NUM_WARPS=4> |
| 155 | +__global__ void append_dequant_cache_kv_c16( |
| 156 | + const T *__restrict__ cache_k, |
| 157 | + const T *__restrict__ cache_v, |
| 158 | + T *__restrict__ k_out, |
| 159 | + T *__restrict__ v_out, |
| 160 | + const int *__restrict__ seq_lens_this_time, |
| 161 | + const int *__restrict__ seq_lens_decoder, |
| 162 | + const int *__restrict__ cu_seqlens_k, |
| 163 | + const int *__restrict__ block_tables, |
| 164 | + const int *batch_ids, |
| 165 | + const int *tile_ids_per_batch, |
| 166 | + const int max_blocks_per_seq, |
| 167 | + const int kv_num_heads) { |
| 168 | +  // start_kv_idx: starting kv index of this tile |
| 169 | +  // batch_id: the batch this tile belongs to |
| 170 | +  // TODO: 1. prefetch scales 2. reuse frag_dq_T 3. software pipelining 4. coalesce stores 5. CacheT support (int8/fp8) |
| 171 | + const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; |
| 172 | + const uint32_t tid = threadIdx.x, wid = threadIdx.y; |
| 173 | + |
| 174 | + const uint32_t batch_id = batch_ids[tile_idx]; |
| 175 | + const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; |
| 176 | + const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; |
| 177 | +  if (seq_lens_this_time[batch_id] <= 0) { |
| 178 | + return; |
| 179 | + } |
| 180 | + |
| 181 | + const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; |
| 182 | + uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; |
| 183 | + // cache_kv idx |
| 184 | + uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; |
| 185 | + uint32_t block_stride = kv_num_heads * kv_h_stride; |
| 186 | + const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; |
| 187 | + const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; |
| 188 | + |
| 189 | + // k_out v_out idx |
| 190 | + uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; |
| 191 | +  T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // start pointer of the current k block |
| 192 | +  T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // start pointer of the current v block |
| 193 | + |
| 194 | + uint32_t kv_frag[4]; |
| 195 | + T *frag_dq_T = reinterpret_cast<T *>(kv_frag); |
| 196 | + |
| 197 | + constexpr uint32_t num_vecs_per_head = |
| 198 | + HEAD_DIM / num_elems_per_128b<CacheT>(); |
| 199 | + constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head; |
| 200 | + |
| 201 | + extern __shared__ uint8_t smem[]; |
| 202 | + smem_t k_smem(smem); |
| 203 | + uint32_t k_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( |
| 204 | + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp |
| 205 | + |
| 206 | + uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( |
| 207 | + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); |
| 208 | + |
| 209 | + uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + |
| 210 | + tid % 8 * num_elems_per_128b<CacheT>(); |
| 211 | + |
| 212 | +  // load k_smem: 64 rows x 128 columns |
| 213 | +  for (int fz = 0; fz < 4; fz++) { // 4 rows per warp per iteration; 4 iterations = 16 rows, 4 warps = 64 rows |
| 214 | +    for (int fy = 0; fy < 2; fy++) { // 8 x 128b at a time = 64 bf16 |
| 215 | + k_smem.load_128b_async<SharedMemFillMode::kNoFill>( |
| 216 | + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); |
| 217 | + k_smem_offset_w = |
| 218 | + k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy); |
| 219 | + k_read_idx += 8 * num_elems_per_128b<CacheT>(); |
| 220 | + } |
| 221 | + k_smem_offset_w = |
| 222 | + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 8; |
| 223 | + k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>(); |
| 224 | + } |
| 225 | + commit_group(); |
| 226 | + wait_group<0>(); |
| 227 | + __syncthreads(); |
| 228 | + |
| 229 | +  // process k_smem: 64 rows x 128 columns |
| 230 | +  for (int fz = 0; fz < 1; fz++) { // 16 rows per warp; 4 warps = 64 rows |
| 231 | + uint32_t row_idx = wid * 16 + tid / 4; |
| 232 | +    for (int fy = 0; fy < 8; fy++) { // 2 x 128b (16 bf16) per iteration; 8 iterations = 16 x 128b (128 bf16) |
| 233 | + uint32_t col_idx = fy * 16 + tid % 4 * 2; |
| 234 | + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); |
| 235 | +      // store; per-lane element positions in the 16x16 tile: |
| 236 | + /*** |
| 237 | + r0c0,r0c1, r0c8,r0c9 |
| 238 | + r8c0,r8c1, r8c8,r8c9 |
| 239 | + ***/ |
| 240 | + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; |
| 241 | + T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; |
| 242 | + |
| 243 | + if (row_idx < end_idx) { |
| 244 | + k_tile_ptr0[0] = frag_dq_T[0]; |
| 245 | + k_tile_ptr0[1] = frag_dq_T[1]; |
| 246 | + k_tile_ptr0[8] = frag_dq_T[4]; |
| 247 | + k_tile_ptr0[9] = frag_dq_T[5]; |
| 248 | + } |
| 249 | + |
| 250 | + if (row_idx + 8 < end_idx) { |
| 251 | + k_tile_ptr1[0] = frag_dq_T[2]; |
| 252 | + k_tile_ptr1[1] = frag_dq_T[3]; |
| 253 | +        k_tile_ptr1[8] = frag_dq_T[6]; |
| 254 | +        k_tile_ptr1[9] = frag_dq_T[7]; |
| 255 | + } |
| 256 | + k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( |
| 257 | + k_smem_offset_r, fy); |
| 258 | + } |
| 259 | + k_smem_offset_r = |
| 260 | + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16; |
| 261 | + } |
| 262 | + // ================v================ |
| 263 | + smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); |
| 264 | + uint32_t v_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( |
| 265 | + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp |
| 266 | + uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head, inv_kv_stride>( |
| 267 | + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); |
| 268 | + |
| 269 | + uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + |
| 270 | + tid % 8 * num_elems_per_128b<CacheT>(); |
| 271 | + |
| 272 | +  // load v_smem: 64 rows x 128 columns |
| 273 | +  for (int fz = 0; fz < 4; fz++) { // 4 rows per warp per iteration; 4 iterations = 16 rows, 4 warps = 64 rows |
| 274 | +    for (int fy = 0; fy < 2; fy++) { // 8 x 128b at a time = 64 bf16 |
| 275 | + v_smem.load_128b_async<SharedMemFillMode::kNoFill>( |
| 276 | + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); |
| 277 | + v_smem_offset_w = |
| 278 | + v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy); |
| 279 | + v_read_idx += 8 * num_elems_per_128b<CacheT>(); |
| 280 | + } |
| 281 | + v_smem_offset_w = |
| 282 | + v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 8; |
| 283 | + v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b<CacheT>(); |
| 284 | + } |
| 285 | + commit_group(); |
| 286 | + wait_group<0>(); |
| 287 | + __syncthreads(); |
| 288 | + |
| 289 | +  // process v_smem: 64 rows x 128 columns |
| 290 | + |
| 291 | +  for (int fz = 0; fz < 1; fz++) { // 16 rows per warp; 4 warps = 64 rows |
| 292 | + uint32_t row_idx = wid * 16 + tid / 4; |
| 293 | +    for (int fy = 0; fy < 8; fy++) { // 2 x 128b (16 bf16) per iteration; 8 iterations = 16 x 128b (128 bf16) |
| 294 | + uint32_t col_idx = fy * 16 + tid % 4 * 2; |
| 295 | + v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); |
| 296 | +      // store; per-lane element positions in the 16x16 tile: |
| 297 | + /*** |
| 298 | + r0c0,r0c1, r0c8,r0c9 |
| 299 | + r8c0,r8c1, r8c8,r8c9 |
| 300 | + ***/ |
| 301 | + T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; |
| 302 | + T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride; |
| 303 | + |
| 304 | + if (row_idx < end_idx) { |
| 305 | + v_tile_ptr0[0] = frag_dq_T[0]; |
| 306 | + v_tile_ptr0[1] = frag_dq_T[1]; |
| 307 | + v_tile_ptr0[8] = frag_dq_T[4]; |
| 308 | + v_tile_ptr0[9] = frag_dq_T[5]; |
| 309 | + } |
| 310 | + |
| 311 | + if (row_idx + 8 < end_idx) { |
| 312 | + v_tile_ptr1[0] = frag_dq_T[2]; |
| 313 | + v_tile_ptr1[1] = frag_dq_T[3]; |
| 314 | +        v_tile_ptr1[8] = frag_dq_T[6]; |
| 315 | +        v_tile_ptr1[9] = frag_dq_T[7]; |
| 316 | + } |
| 317 | + v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>( |
| 318 | + v_smem_offset_r, fy); |
| 319 | + } |
| 320 | + v_smem_offset_r = |
| 321 | + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16; |
| 322 | + } |
| 323 | +} |
| 324 | + |
151 | 325 | template <typename T,
|
152 | 326 | typename CacheT,
|
153 | 327 | uint32_t HEAD_DIM,
|
@@ -214,7 +388,7 @@ __global__ void append_dequant_cache_kv_c8(
|
214 | 388 |
|
215 | 389 | uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
|
216 | 390 | wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
217 |
| - |
| 391 | + |
218 | 392 | uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
|
219 | 393 | tid % 8 * num_elems_per_128b<CacheT>();
|
220 | 394 |
|
@@ -328,7 +502,7 @@ __global__ void append_dequant_cache_kv_c8(
|
328 | 502 | v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
|
329 | 503 | v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
|
330 | 504 |
|
331 |
| - |
| 505 | + |
332 | 507 | convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4 uint8/fp8 values -> 4 T values
|
333 | 507 | #ifdef C8_DEBUG
|
334 | 508 | if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
|
@@ -371,14 +545,36 @@ void AppendDequantCache(
|
371 | 545 | paddle::Tensor *k_out,
|
372 | 546 | paddle::Tensor *v_out,
|
373 | 547 | const cudaStream_t& stream
|
374 |
| -) { |
| 548 | +) { |
375 | 549 | using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
376 |
| - if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { |
377 |
| - constexpr int NUM_WARPS = 4; |
378 |
| - int block_num = cache_num_blocks_x.data<int>()[0]; |
379 |
| - dim3 grids(block_num, 1, kv_num_heads); |
380 |
| - dim3 blocks(32, NUM_WARPS); |
381 |
| - |
| 550 | + constexpr int NUM_WARPS = 4; |
| 551 | + int block_num = cache_num_blocks_x.data<int>()[0]; |
| 552 | + dim3 grids(block_num, 1, kv_num_heads); |
| 553 | + dim3 blocks(32, NUM_WARPS); |
| 554 | + if (cache_quant_type == "none") { |
| 555 | + const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; |
| 556 | + auto kernel_func = append_dequant_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>; |
| 557 | + |
| 558 | + if (smem_size >= 48 * 1024) { |
| 559 | + cudaFuncSetAttribute(kernel_func, |
| 560 | + cudaFuncAttributeMaxDynamicSharedMemorySize, |
| 561 | + smem_size); |
| 562 | + } |
| 563 | + kernel_func<<<grids, blocks, smem_size, stream>>>( |
| 564 | + reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())), |
| 565 | + reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())), |
| 566 | + reinterpret_cast<NV_TYPE *>(k_out->data<T>()), |
| 567 | + reinterpret_cast<NV_TYPE *>(v_out->data<T>()), |
| 568 | + seq_lens_this_time.data<int>(), |
| 569 | + seq_lens_decoder.data<int>(), |
| 570 | + cu_seqlens_k.data<int>(), |
| 571 | + block_tables.data<int>(), |
| 572 | + cache_batch_ids.data<int>(), |
| 573 | + cache_tile_ids_per_batch.data<int>(), |
| 574 | + max_blocks_per_seq, |
| 575 | + kv_num_heads |
| 576 | + ); |
| 577 | + } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { |
382 | 578 | const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
|
383 | 579 |
|
384 | 580 | auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
|
|