@@ -152,7 +152,7 @@ template <typename T,
           uint32_t HEAD_DIM,
           uint32_t BLOCK_SIZE,
           uint32_t NUM_WARPS=4>
-__global__ void append_dequant_cache_kv_c16(
+__global__ void append_cache_kv_c16(
     const T *__restrict__ cache_k,
     const T *__restrict__ cache_v,
     T *__restrict__ k_out,
@@ -174,7 +174,7 @@ __global__ void append_dequant_cache_kv_c16(
   const uint32_t batch_id = batch_ids[tile_idx];
   const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
   const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
-  if (seq_lens_this_time <= 0) {
+  if (seq_lens_this_time[batch_id] <= 0) {
     return;
   }

@@ -250,8 +250,8 @@ __global__ void append_dequant_cache_kv_c16(
       if (row_idx + 8 < end_idx) {
         k_tile_ptr1[0] = frag_dq_T[2];
         k_tile_ptr1[1] = frag_dq_T[3];
-        k_tile_ptr0[8] = frag_dq_T[6];
-        k_tile_ptr0[9] = frag_dq_T[7];
+        k_tile_ptr1[8] = frag_dq_T[6];
+        k_tile_ptr1[9] = frag_dq_T[7];
       }
       k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
           k_smem_offset_r, fy);
@@ -311,8 +311,8 @@ __global__ void append_dequant_cache_kv_c16(
       if (row_idx + 8 < end_idx) {
         v_tile_ptr1[0] = frag_dq_T[2];
         v_tile_ptr1[1] = frag_dq_T[3];
-        v_tile_ptr0[8] = frag_dq_T[6];
-        v_tile_ptr0[9] = frag_dq_T[7];
+        v_tile_ptr1[8] = frag_dq_T[6];
+        v_tile_ptr1[9] = frag_dq_T[7];
       }
       v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
           v_smem_offset_r, fy);
@@ -328,7 +328,7 @@ template <typename T,
           uint32_t BLOCK_SIZE,
           uint32_t NUM_WARPS=4,
           bool IS_FP8=false>
-__global__ void append_dequant_cache_kv_c8(
+__global__ void append_cache_kv_c8(
     const CacheT *__restrict__ cache_k,
     const CacheT *__restrict__ cache_v,
     T *__restrict__ k_out,
@@ -527,7 +527,7 @@ __global__ void append_dequant_cache_kv_c8(
 }

 template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
-void AppendDequantCache(
+void AppendCacheKV(
     const paddle::Tensor &cache_k,
     const paddle::Tensor &cache_v,
     const paddle::Tensor &cache_k_dequant_scales,
@@ -553,7 +553,7 @@ void AppendDequantCache(
   dim3 blocks(32, NUM_WARPS);
   if (cache_quant_type == "none") {
     const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
-    auto kernel_func = append_dequant_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;
+    auto kernel_func = append_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;

     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(kernel_func,
@@ -577,9 +577,9 @@ void AppendDequantCache(
   } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
     const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

-    auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
+    auto kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
     if (cache_quant_type == "cache_fp8") {
-      kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
+      kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
     }
     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(kernel_func,
@@ -757,7 +757,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   }

   if (token_num < kv_token_num) {
-    AppendDequantCache<data_t, 128, 64>(
+    AppendCacheKV<data_t, 128, 64>(
         key_cache,
         value_cache,
         cache_k_dequant_scales.get(),
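For context on the launch path touched by the last two host-side hunks: CUDA kernels get at most 48KB of dynamic shared memory by default, so when the requested smem_size is larger the code raises the per-kernel limit with cudaFuncSetAttribute before launching the selected instantiation. The standalone sketch below reproduces that dispatch pattern under stated assumptions; demo_append_cache_kv_c8, its stub body, and the launch shapes are illustrative placeholders and not code from this diff.

// Standalone sketch of the quant-type dispatch + dynamic-shared-memory
// opt-in pattern used by AppendCacheKV above. Names and kernel body are
// placeholders for illustration only.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

template <bool IS_FP8>
__global__ void demo_append_cache_kv_c8(const uint8_t *cache_k, float *k_out) {
  // Dynamic shared-memory staging buffer, sized at launch time.
  extern __shared__ uint8_t smem[];
  (void)smem;  // unused in this stub
  // A real kernel would stage cache_k tiles into smem, dequantize them
  // (int8 or fp8 depending on IS_FP8), and write the result to k_out.
  if (cache_k == nullptr || k_out == nullptr) return;
}

int main() {
  constexpr uint32_t HEAD_DIM = 128, BLOCK_SIZE = 64, NUM_WARPS = 4;
  const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

  // Pick the template instantiation by quant type, mirroring the
  // "cache_int8" / "cache_fp8" branch in the diff.
  const bool use_fp8 = true;
  auto kernel_func = demo_append_cache_kv_c8<false>;
  if (use_fp8) {
    kernel_func = demo_append_cache_kv_c8<true>;
  }

  // Requests above the 48KB default must be enabled explicitly, which is
  // why the diff guards cudaFuncSetAttribute with this size check.
  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(kernel_func,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }

  dim3 grids(1);
  dim3 blocks(32, NUM_WARPS);
  kernel_func<<<grids, blocks, smem_size>>>(nullptr, nullptr);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}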