From 393eb10fb61fbb9b6e8c5eb4b22066a4284273e9 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Wed, 9 Jul 2025 14:26:50 -0700
Subject: [PATCH] Migrate quantization kernels to `FBGEMM_LAUNCH_KERNEL`, pt 1

Summary:
- Migrate quantization kernels to `FBGEMM_LAUNCH_KERNEL`, pt 1

Reviewed By: r-barnes

Differential Revision: D75259308
---
 fbgemm_gpu/src/quantize_ops/common.cuh  |   1 +
 .../quantize_fused_8bit_rowwise.cu      | 105 +++++++++---------
 .../quantize_fused_nbit_rowwise.cu      |  43 +++----
 3 files changed, 79 insertions(+), 70 deletions(-)

diff --git a/fbgemm_gpu/src/quantize_ops/common.cuh b/fbgemm_gpu/src/quantize_ops/common.cuh
index 49dd66cfb3..7635ad4f67 100644
--- a/fbgemm_gpu/src/quantize_ops/common.cuh
+++ b/fbgemm_gpu/src/quantize_ops/common.cuh
@@ -25,6 +25,7 @@
 #include "fbgemm_gpu/utils/cuda_block_count.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/float.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
index cca2bb7ed7..f2e85efb4e 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -270,16 +270,16 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
   if (nrows <= 20) {
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_float_to_fused8bitrowwise_cuda_kernel", [&] {
-          _float_to_fused8bitrowwise_cuda_kernel
-              <<<num_blocks,
-                 threads_per_block,
-                 0,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr(),
-                  nrows,
-                  ncols,
-                  output.data_ptr());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_float_to_fused8bitrowwise_cuda_kernel),
+              num_blocks,
+              threads_per_block,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr(),
+              nrows,
+              ncols,
+              output.data_ptr());
         });
   } else {
     // range_tensor is used to store the range for each embedding row.
@@ -308,17 +308,17 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_get_8bit_qparam_cuda_kernel", [&] {
-          _get_8bit_qparam_cuda_kernel
-              <<<num_blocks_warp,
-                 dim3(blockDim_x, rows_per_block),
-                 0,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr(),
-                  nrows,
-                  ncols,
-                  output.data_ptr(),
-                  range_tensor.data_ptr());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_get_8bit_qparam_cuda_kernel),
+              num_blocks_warp,
+              dim3(blockDim_x, rows_per_block),
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr(),
+              nrows,
+              ncols,
+              output.data_ptr(),
+              range_tensor.data_ptr());
         });
   }
@@ -331,14 +331,17 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_compute_8bit_quantize_cuda_kernel", [&] {
-          _compute_8bit_quantize_cuda_kernel
-              <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr(),
-                  range_tensor.data_ptr(),
-                  nrows,
-                  ncols,
-                  output.data_ptr());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_compute_8bit_quantize_cuda_kernel),
+              gridDim,
+              blockDim,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr(),
+              range_tensor.data_ptr(),
+              nrows,
+              ncols,
+              output.data_ptr());
         });
     }
   }
@@ -448,16 +451,20 @@ Tensor _fused8bitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
-#define DEQUANT_LAUNCH(scale_bias_last, quant_padding_float_type) \
-  _fused8bitrowwise_to_float_cuda_kernel< \
-      scalar_t, \
-      scale_bias_last, \
-      quant_padding_float_type> \
-      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          input.data_ptr(), \
-          nrows, \
-          ncols, \
-          output.data_ptr())
+#define DEQUANT_LAUNCH(scale_bias_last, quant_padding_float_type) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (_fused8bitrowwise_to_float_cuda_kernel< \
+          scalar_t, \
+          scale_bias_last, \
+          quant_padding_float_type>), \
+      gridDim, \
+      blockDim, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      input.data_ptr(), \
+      nrows, \
+      ncols, \
+      output.data_ptr())
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fused8bitrowwise_to_float_cuda_kernel", [&] {
@@ -474,7 +481,6 @@ Tensor _fused8bitrowwise_to_float_gpu_t(
             DEQUANT_LAUNCH(false, false);
           }
         }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 #undef DEQUANT_LAUNCH
   return output;
@@ -600,16 +606,15 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu(
       output.scalar_type(),
       "_fused8bitrowwise_to_float_mixed_dim_cuda_kernel",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "_fused8bitrowwise_to_float_mixed_dim_cuda_kernel";
-#endif
-        _fused8bitrowwise_to_float_mixed_dim_cuda_kernel
-            <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                MAKE_PTA_WITH_NAME(func_name, input, uint8_t, 2, 32),
-                MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-                MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        FBGEMM_LAUNCH_KERNEL(
+            (_fused8bitrowwise_to_float_mixed_dim_cuda_kernel),
+            gridDim,
+            blockDim,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(input, uint8_t, 2, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
+            PTA_B(output, scalar_t, 2, 32));
       });
   return output;
 }
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
index 8791060933..456cadb923 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -149,17 +149,17 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
   FBGEMM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "_float_to_fusednbitrowwise_cuda_kernel", [&] {
-        _float_to_fusednbitrowwise_cuda_kernel
-            <<<num_blocks,
-               threads_per_block,
-               0,
-               at::cuda::getCurrentCUDAStream()>>>(
-                bit_rate,
-                input.data_ptr(),
-                nrows,
-                ncols,
-                output.data_ptr());
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        FBGEMM_LAUNCH_KERNEL(
+            (_float_to_fusednbitrowwise_cuda_kernel),
+            num_blocks,
+            threads_per_block,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            bit_rate,
+            input.data_ptr(),
+            nrows,
+            ncols,
+            output.data_ptr());
       });
   return output;
@@ -267,14 +267,18 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
-#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
-  _fusednbitrowwise_to_float_cuda_kernel \
-      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          bit_rate, \
-          input.data_ptr(), \
-          nrows, \
-          ncols, \
-          output.data_ptr())
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (_fusednbitrowwise_to_float_cuda_kernel), \
+      gridDim, \
+      blockDim, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      bit_rate, \
+      input.data_ptr(), \
+      nrows, \
+      ncols, \
+      output.data_ptr())
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
@@ -283,7 +287,6 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
         } else {
           DEQUANT_LAUNCH_NBIT(false);
         }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 #undef DEQUANT_LAUNCH_NBIT
   return output;
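
Note: the following is a minimal sketch of the launch-site pattern this patch applies, not code from this diff. It assumes only the `FBGEMM_LAUNCH_KERNEL` argument order visible in the hunks above (kernel, grid dim, block dim, shared-memory bytes, stream, then the kernel arguments); the kernel `toy_scale_kernel`, the wrapper `launch_toy_scale`, and their arguments are hypothetical placeholders.

  #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAException.h>
  #include "fbgemm_gpu/utils/kernel_launcher.cuh"  // header added by this patch

  // Hypothetical kernel, used only to illustrate the migration.
  template <typename scalar_t>
  __global__ void toy_scale_kernel(const scalar_t* in, scalar_t* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
      out[i] = in[i] * scalar_t(2);
    }
  }

  void launch_toy_scale(const float* in, float* out, int n) {
    const dim3 threads_per_block(256);
    const dim3 num_blocks((n + 255) / 256);

    // Before: raw triple-chevron launch followed by an explicit launch check.
    toy_scale_kernel<float>
        <<<num_blocks, threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
            in, out, n);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // After: the same launch configuration passed as explicit macro arguments;
    // the separate C10_CUDA_KERNEL_LAUNCH_CHECK() call is dropped, as in the
    // hunks above. The kernel is parenthesized so template commas do not split
    // the macro argument list.
    FBGEMM_LAUNCH_KERNEL(
        (toy_scale_kernel<float>),
        num_blocks,
        threads_per_block,
        0,
        at::cuda::getCurrentCUDAStream(),
        in,
        out,
        n);
  }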