@@ -279,15 +279,16 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
-__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
+                                               float* top_p_arr, IdType* top_k_arr,
                                                uint32_t d, uint64_t philox_seed,
                                                uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
   curand_init(philox_seed, bx, philox_offset, &state);
   const uint32_t row_idx = bx;
-  const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
   const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];

   extern __shared__ __align__(
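
Note on the hunk above: each row now reads its own top-k value, and both per-row parameters use 0 as a sentinel. A zero in top_k_arr disables the top-k filter (k falls back to the full vocabulary size d), while a zero in top_p_arr is clamped to 1e-6 rather than used directly. A minimal host-side sketch of the same mapping, useful as a reference when filling the two arrays (the helper name and struct are illustrative, not part of the patch):

#include <cstdint>

struct EffectiveSamplingParams {
  uint32_t k;  // candidates kept by the top-k filter
  float p;     // threshold used by the top-p filter
};

// Mirrors the sentinel handling in TopKTopPSamplingFromProbKernel:
//   top_k == 0 -> keep all d entries (top-k disabled)
//   top_p == 0 -> clamp to a tiny positive threshold (1e-6)
inline EffectiveSamplingParams ResolveSamplingParams(uint32_t top_k, float top_p, uint32_t d) {
  return {top_k == 0 ? d : top_k, top_p == 0.0f ? 1e-6f : top_p};
}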
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
     if (aggregate_gt_pivot_0 < top_p) {
       // case 1: pivot_0 accepted
       break;
-    }
+    }
     if (aggregate_gt_pivot_1 < top_p) {
       // case 2: pivot_0 rejected, pivot_1 accepted
       low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
   }
 }

+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          typename TempStorage>
+__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
+                                             TempStorage& temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  vec_t<float, VEC_SIZE> in_data_vec;
+
+  float max_val = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    in_data_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    float in_data_[VEC_SIZE];
+    #pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      in_data_[j] = in_data_vec[j];
+    }
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.max_val = max_val;
+  }
+  __syncthreads();
+  return temp_storage.max_val;
+}
+
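
GetMaxValue above walks one row of in_data in BLOCK_THREADS * VEC_SIZE chunks, reduces each chunk with cub's BlockReduce, and broadcasts the result through shared memory so every thread in the block returns the same row maximum. For intuition, a scalar equivalent (illustration only; like the device code it starts from 0, which is safe here because the inputs are probabilities):

#include <algorithm>
#include <cstdint>

// Scalar reference for GetMaxValue: maximum over row `row_idx` of a row-major
// [num_rows, d] array, clamped below at 0 as in the device code.
inline float RowMaxReference(const float* in_data, uint32_t row_idx, uint32_t d) {
  float max_val = 0.0f;
  for (uint32_t j = 0; j < d; ++j) {
    max_val = std::max(max_val, in_data[row_idx * d + j]);
  }
  return max_val;
}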
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
+    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
+    typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_value_count;
+  } block_prim;
+  struct {
+    float max_val;
+    float min_val;
+    union {
+      struct {
+        float values[2];
+      };
+      struct {
+        int counts[2];
+      };
+      struct {
+        ValueCount<float> pairs[2];
+      };
+    } block_aggregate;
+  };
+};
+
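
RenormTempStorage is the single dynamically allocated shared-memory block used by the renormalization kernels: block_prim is a union, so the float, int, and ValueCount<float> BlockReduce scratch areas share the same bytes (only one reduction runs between __syncthreads() barriers), and block_aggregate overlays the per-iteration broadcast slots the same way. A tiny standalone sketch of the overlay idea (types are illustrative):

// A union occupies only as much storage as its largest member, so scratch
// areas that are never live at the same time can share one allocation.
struct ScratchFloat { float buf[32]; };
struct ScratchPair  { float value[32]; int count[32]; };

union SharedScratch {
  ScratchFloat reduce;
  ScratchPair reduce_value_count;
};

static_assert(sizeof(SharedScratch) >= sizeof(ScratchFloat) &&
                  sizeof(SharedScratch) >= sizeof(ScratchPair),
              "one allocation covers whichever member is currently in use");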
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
+  double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
+  vec_t<float, VEC_SIZE> probs_vec;
+  if (k < d) {
+    extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
+    temp_storage.max_val = 0;
+
+    float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                                RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
+        probs, row_idx, d, temp_storage);
+
+    double low = 0, high = max_val;
+    float min_gt_low, max_le_high;
+    float sum_low = 1;
+    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
+      double pivot_0 = (high + 2 * low) / 3;
+      double pivot_1 = (2 * high + low) / 3;
+
+      ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+      min_gt_low = high;
+      max_le_high = low;
+      #pragma unroll 2
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        probs_vec.fill(0);
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+        }
+        ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
+        #pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_gt_pivot_0_pair[j] = {
+              (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          probs_gt_pivot_1_pair[j] = {
+              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+
+          if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
+        }
+
+        aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
+        __syncthreads();
+
+        aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
+        __syncthreads();
+      }
+      min_gt_low =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
+        temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
+        temp_storage.min_val = min_gt_low;
+        temp_storage.max_val = max_le_high;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
+      min_gt_low = temp_storage.min_val;
+      max_le_high = temp_storage.max_val;
+
+      if (aggregate_gt_pivot_1.count >= k) {
+        low = pivot_1;
+        sum_low = float(aggregate_gt_pivot_1.value);
+      } else if (aggregate_gt_pivot_0.count >= k) {
+        low = pivot_0;
+        high = min(pivot_1, max_le_high);
+        sum_low = float(aggregate_gt_pivot_0.value);
+      } else {
+        high = min(pivot_0, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+
+    normalizer = ptx_rcp(max(sum_low, 1e-8));
+    pivot = low;
+  }
+
+  // normalize
+  #pragma unroll 2
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+    #pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
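
The kernel above finds the cutoff by ternary search on a pivot value: f(x), the number of probabilities strictly greater than x, is non-increasing, and the loop shrinks [low, high] until every remaining candidate in the interval has the same value, so the entries kept (those greater than the final low) are the ones at or above the k-th largest probability. Everything else is zeroed and the survivors are rescaled by 1 / sum_low. A single-threaded reference of that end result, handy for checking the kernel on small inputs (the function name is illustrative; ties at the threshold are all kept, as in the kernel):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

// Reference for one row of TopKRenormProbKernel: keep every entry >= the k-th
// largest value, zero the rest, and renormalize the kept mass to sum to 1.
inline std::vector<float> TopKRenormReference(const std::vector<float>& probs, uint32_t k) {
  const uint32_t d = static_cast<uint32_t>(probs.size());
  if (k == 0 || k >= d) return probs;  // k == 0 is the "top-k disabled" sentinel
  std::vector<float> sorted = probs;
  std::nth_element(sorted.begin(), sorted.begin() + (k - 1), sorted.end(), std::greater<float>());
  const float threshold = sorted[k - 1];  // k-th largest probability
  float kept_sum = 0.0f;
  for (float p : probs) kept_sum += (p >= threshold) ? p : 0.0f;
  kept_sum = std::max(kept_sum, 1e-8f);  // same guard as ptx_rcp(max(sum_low, 1e-8))
  std::vector<float> out(d, 0.0f);
  for (uint32_t j = 0; j < d; ++j) {
    out[j] = (probs[j] >= threshold) ? probs[j] / kept_sum : 0.0f;
  }
  return out;
}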
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
                                  uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,

 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
-                                     uint32_t batch_size, const T *top_p_val,
+                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
                                      uint32_t d, bool deterministic,
                                      uint64_t philox_seed, uint64_t philox_offset,
                                      cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &output, &top_p_val,
+  void* args[] = {&probs, &output, &top_p_val, &top_k_val,
                   &d, &philox_seed, &philox_offset};

   DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   });
 }

-}  // namespace sampling
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
+                           uint32_t batch_size, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  auto compute_capacity = GetCudaComputeCapability();
+  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
+    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
+    dim3 nblks(batch_size);
+    dim3 nthrs(BLOCK_THREADS);
+    void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
+    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+      auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+      CUDA_CALL(
+          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args, smem_size, stream));
+    });
+    return cudaSuccess;
+  });
+}
+
+}  // namespace sampling
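
For orientation, a usage sketch of the two host entry points this patch touches (everything here is illustrative: buffer sizes, names, and the omitted steps that fill probs / top_k / top_p with real data; error handling is also skipped):

#include <cstdint>
#include <cuda_runtime.h>

int main() {
  const uint32_t batch_size = 8, vocab_size = 32000;
  float *probs, *renormed_probs, *top_p;
  int *top_k, *sampled_ids;
  cudaMalloc(&probs, sizeof(float) * batch_size * vocab_size);
  cudaMalloc(&renormed_probs, sizeof(float) * batch_size * vocab_size);
  cudaMalloc(&top_p, sizeof(float) * batch_size);
  cudaMalloc(&top_k, sizeof(int) * batch_size);
  cudaMalloc(&sampled_ids, sizeof(int) * batch_size);
  // ... fill probs with per-row distributions and top_k / top_p with per-row
  // parameters (0 meaning "disabled", as in the kernels above) ...

  // Renormalize each row onto its top-k support.
  sampling::TopKRenormProb<float, int>(probs, renormed_probs, top_k, batch_size, vocab_size);

  // Fused top-k + top-p sampling, now taking the per-row top_k array as well.
  sampling::TopKTopPSamplingFromProb<float, int>(probs, sampled_ids, batch_size, top_p, top_k,
                                                 vocab_size, /*deterministic=*/true,
                                                 /*philox_seed=*/42, /*philox_offset=*/0);
  return 0;
}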