diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index eb5f6f1b84..c4764c00c5 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -287,7 +287,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = bx; - const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20; + const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 40; const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx]; extern __shared__ __align__(