
Commit e477ccf

fix prefix_cache in fa3
1 parent 1b1ecb3 commit e477ccf

File tree

1 file changed (+12, -12)

custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu

Lines changed: 12 additions & 12 deletions
@@ -152,7 +152,7 @@ template <typename T,
           uint32_t HEAD_DIM,
           uint32_t BLOCK_SIZE,
           uint32_t NUM_WARPS=4>
-__global__ void append_dequant_cache_kv_c16(
+__global__ void append_cache_kv_c16(
     const T *__restrict__ cache_k,
     const T *__restrict__ cache_v,
     T *__restrict__ k_out,
@@ -174,7 +174,7 @@ __global__ void append_dequant_cache_kv_c16(
   const uint32_t batch_id = batch_ids[tile_idx];
   const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE;
   const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx;
-  if (seq_lens_this_time <= 0) {
+  if (seq_lens_this_time[batch_id] <= 0) {
     return;
   }

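In the removed line, seq_lens_this_time is a device pointer, so the old guard compared the pointer itself rather than a per-batch sequence length; a valid pointer is never <= 0, so the early return never fired and tiles of batches with no tokens this step were still processed. A minimal sketch of the bug class, with hypothetical kernel names (not the file's real code):

// BUG variant: relational comparison against the pointer itself.
__global__ void guard_bad(const int *seq_lens_this_time) {
  // A valid device pointer is never <= 0, so this early return is dead code.
  if (seq_lens_this_time <= 0) return;
  // ... proceeds even for batches with zero tokens this step ...
}

// FIXED variant: index the per-batch length before comparing, as the hunk does.
__global__ void guard_good(const int *seq_lens_this_time, const int batch_id) {
  if (seq_lens_this_time[batch_id] <= 0) return;
  // ... runs only for batches that actually have tokens this step ...
}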
@@ -250,8 +250,8 @@ __global__ void append_dequant_cache_kv_c16(
       if (row_idx + 8 < end_idx) {
         k_tile_ptr1[0] = frag_dq_T[2];
         k_tile_ptr1[1] = frag_dq_T[3];
-        k_tile_ptr0[8] = frag_dq_T[6];
-        k_tile_ptr0[9] = frag_dq_T[7];
+        k_tile_ptr1[8] = frag_dq_T[6];
+        k_tile_ptr1[9] = frag_dq_T[7];
       }
       k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>(
           k_smem_offset_r, fy);
@@ -311,8 +311,8 @@ __global__ void append_dequant_cache_kv_c16(
       if (row_idx + 8 < end_idx) {
         v_tile_ptr1[0] = frag_dq_T[2];
         v_tile_ptr1[1] = frag_dq_T[3];
-        v_tile_ptr0[8] = frag_dq_T[6];
-        v_tile_ptr0[9] = frag_dq_T[7];
+        v_tile_ptr1[8] = frag_dq_T[6];
+        v_tile_ptr1[9] = frag_dq_T[7];
       }
       v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>(
           v_smem_offset_r, fy);
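Both tile-pointer hunks (K above, V here) fix the same mapping error: each thread's eight dequantized values frag_dq_T[0..7] cover two tile rows spaced 8 apart, and elements [6] and [7] belong to the second row (tile_ptr1, i.e. row_idx + 8), not the first. A hedged sketch of the corrected scatter; the tile_ptr0 half and the 16x16 fragment layout are inferred from the surrounding guard, not shown in the diff:

#include <cstdint>

// Hypothetical helper illustrating the fragment-to-tile scatter after the fix.
// frag holds one thread's 8 values of a 16x16 MMA output tile: rows row_idx
// and row_idx + 8, column pairs {c, c+1} and {c+8, c+9}.
template <typename T>
__device__ void scatter_frag(T *tile_ptr0,   // &tile[row_idx][c]
                             T *tile_ptr1,   // &tile[row_idx + 8][c]
                             const T frag[8],
                             uint32_t row_idx,
                             uint32_t end_idx) {
  if (row_idx < end_idx) {      // first fragment row (assumed symmetric code)
    tile_ptr0[0] = frag[0];
    tile_ptr0[1] = frag[1];
    tile_ptr0[8] = frag[4];
    tile_ptr0[9] = frag[5];
  }
  if (row_idx + 8 < end_idx) {  // second fragment row, as in the hunks
    tile_ptr1[0] = frag[2];
    tile_ptr1[1] = frag[3];
    tile_ptr1[8] = frag[6];     // previously stored through tile_ptr0,
    tile_ptr1[9] = frag[7];     // landing in the wrong row
  }
}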
@@ -328,7 +328,7 @@ template <typename T,
           uint32_t BLOCK_SIZE,
           uint32_t NUM_WARPS=4,
           bool IS_FP8=false>
-__global__ void append_dequant_cache_kv_c8(
+__global__ void append_cache_kv_c8(
     const CacheT *__restrict__ cache_k,
     const CacheT *__restrict__ cache_v,
     T *__restrict__ k_out,
@@ -527,7 +527,7 @@ __global__ void append_dequant_cache_kv_c8(
 }

 template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
-void AppendDequantCache(
+void AppendCacheKV(
     const paddle::Tensor &cache_k,
     const paddle::Tensor &cache_v,
     const paddle::Tensor &cache_k_dequant_scales,
@@ -553,7 +553,7 @@ void AppendDequantCache(
   dim3 blocks(32, NUM_WARPS);
   if (cache_quant_type == "none") {
     const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2;
-    auto kernel_func = append_dequant_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;
+    auto kernel_func = append_cache_kv_c16<NV_TYPE, NV_TYPE, HEAD_DIM, BLOCK_SIZE, NUM_WARPS>;

     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(kernel_func,
@@ -577,9 +577,9 @@ void AppendDequantCache(
   } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
     const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

-    auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
+    auto kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;
     if (cache_quant_type == "cache_fp8") {
-      kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
+      kernel_func = append_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, true>;
     }
     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(kernel_func,
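For context on the cudaFuncSetAttribute calls retained in both branches: a kernel that needs more than 48 KiB of dynamic shared memory must opt in before launch, which is why smem_size is checked against 48 * 1024. A self-contained sketch of the pattern with a hypothetical kernel (only the attribute call mirrors the diff):

#include <cstdint>
#include <cuda_runtime.h>

__global__ void stage_through_smem(const float *in, float *out) {
  extern __shared__ float smem[];  // dynamic shared memory
  smem[threadIdx.x] = in[threadIdx.x];
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

void launch(const float *in, float *out, cudaStream_t stream) {
  const uint32_t smem_size = 64 * 1024;  // 64 KiB exceeds the 48 KiB default
  if (smem_size >= 48 * 1024) {
    // Raise the per-kernel cap; without this the launch fails on most GPUs.
    cudaFuncSetAttribute(stage_through_smem,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }
  stage_through_smem<<<1, 1024, smem_size, stream>>>(in, out);
}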
@@ -757,7 +757,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   }

   if (token_num < kv_token_num) {
-    AppendDequantCache<data_t, 128, 64>(
+    AppendCacheKV<data_t, 128, 64>(
       key_cache,
       value_cache,
       cache_k_dequant_scales.get(),
