Migrate quantization kernels to FBGEMM_LAUNCH_KERNEL, pt 1 #4465

Open · wants to merge 1 commit into main
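
Summary of the change: each quantization kernel in fbgemm_gpu/src/quantize_ops that was launched with raw triple-chevron syntax, followed by an explicit C10_CUDA_KERNEL_LAUNCH_CHECK(), now goes through the FBGEMM_LAUNCH_KERNEL macro, which takes the kernel, the launch configuration (grid, block, shared-memory bytes, stream), and the kernel arguments in a single call. The before/after sketch below illustrates the pattern with a hypothetical placeholder kernel (my_kernel is not a kernel in this PR; the real call sites appear in the hunks that follow):

// Before: raw launch plus a manual post-launch error check.
my_kernel<scalar_t>
    <<<num_blocks, threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<scalar_t>(), nrows, ncols);
C10_CUDA_KERNEL_LAUNCH_CHECK();

// After: one macro call carrying the same configuration and arguments.
// The kernel name is wrapped in parentheses because template argument
// lists can contain commas (e.g. my_kernel<scalar_t, true>), and an
// unparenthesized comma would be parsed as a macro-argument separator.
FBGEMM_LAUNCH_KERNEL(
    (my_kernel<scalar_t>),
    num_blocks,
    threads_per_block,
    0,
    at::cuda::getCurrentCUDAStream(),
    input.data_ptr<scalar_t>(),
    nrows,
    ncols);
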
1 change: 1 addition & 0 deletions fbgemm_gpu/src/quantize_ops/common.cuh
@@ -25,6 +25,7 @@
 #include "fbgemm_gpu/utils/cuda_block_count.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/float.cuh"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
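
The only change to common.cuh is the new include: fbgemm_gpu/utils/kernel_launcher.cuh is the header that provides FBGEMM_LAUNCH_KERNEL, judging by the call sites in the two .cu files below.
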
105 changes: 55 additions & 50 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -270,16 +270,16 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
   if (nrows <= 20) {
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_float_to_fused8bitrowwise_cuda_kernel", [&] {
-          _float_to_fused8bitrowwise_cuda_kernel<scalar_t>
-              <<<num_blocks,
-                 threads_per_block,
-                 0,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr<scalar_t>(),
-                  nrows,
-                  ncols,
-                  output.data_ptr<std::uint8_t>());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_float_to_fused8bitrowwise_cuda_kernel<scalar_t>),
+              num_blocks,
+              threads_per_block,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr<scalar_t>(),
+              nrows,
+              ncols,
+              output.data_ptr<std::uint8_t>());
         });
   } else {
     // range_tensor is used to store the range for each embedding row.
@@ -308,17 +308,17 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
 
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_get_8bit_qparam_cuda_kernel", [&] {
-          _get_8bit_qparam_cuda_kernel<scalar_t>
-              <<<num_blocks_warp,
-                 dim3(blockDim_x, rows_per_block),
-                 0,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr<scalar_t>(),
-                  nrows,
-                  ncols,
-                  output.data_ptr<std::uint8_t>(),
-                  range_tensor.data_ptr<float>());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_get_8bit_qparam_cuda_kernel<scalar_t>),
+              num_blocks_warp,
+              dim3(blockDim_x, rows_per_block),
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr<scalar_t>(),
+              nrows,
+              ncols,
+              output.data_ptr<std::uint8_t>(),
+              range_tensor.data_ptr<float>());
         });
   }
 
@@ -331,14 +331,17 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
 
     FBGEMM_DISPATCH_FLOATING_TYPES(
         input.scalar_type(), "_compute_8bit_quantize_cuda_kernel", [&] {
-          _compute_8bit_quantize_cuda_kernel<scalar_t>
-              <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  input.data_ptr<scalar_t>(),
-                  range_tensor.data_ptr<float>(),
-                  nrows,
-                  ncols,
-                  output.data_ptr<std::uint8_t>());
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (_compute_8bit_quantize_cuda_kernel<scalar_t>),
+              gridDim,
+              blockDim,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              input.data_ptr<scalar_t>(),
+              range_tensor.data_ptr<float>(),
+              nrows,
+              ncols,
+              output.data_ptr<std::uint8_t>());
         });
   }
 }
@@ -448,16 +451,20 @@ Tensor _fused8bitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
 
-#define DEQUANT_LAUNCH(scale_bias_last, quant_padding_float_type)   \
-  _fused8bitrowwise_to_float_cuda_kernel<                           \
-      scalar_t,                                                     \
-      scale_bias_last,                                              \
-      quant_padding_float_type>                                     \
-      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          input.data_ptr<std::uint8_t>(),                           \
-          nrows,                                                    \
-          ncols,                                                    \
-          output.data_ptr<scalar_t>())
+#define DEQUANT_LAUNCH(scale_bias_last, quant_padding_float_type) \
+  FBGEMM_LAUNCH_KERNEL(                                           \
+      (_fused8bitrowwise_to_float_cuda_kernel<                    \
+          scalar_t,                                               \
+          scale_bias_last,                                        \
+          quant_padding_float_type>),                             \
+      gridDim,                                                    \
+      blockDim,                                                   \
+      0,                                                          \
+      at::cuda::getCurrentCUDAStream(),                           \
+      input.data_ptr<std::uint8_t>(),                             \
+      nrows,                                                      \
+      ncols,                                                      \
+      output.data_ptr<scalar_t>())
 
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fused8bitrowwise_to_float_cuda_kernel", [&] {
@@ -474,7 +481,6 @@ Tensor _fused8bitrowwise_to_float_gpu_t(
             DEQUANT_LAUNCH(false, false);
           }
         }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 #undef DEQUANT_LAUNCH
   return output;
@@ -600,16 +606,15 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu(
       output.scalar_type(),
       "_fused8bitrowwise_to_float_mixed_dim_cuda_kernel",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "_fused8bitrowwise_to_float_mixed_dim_cuda_kernel";
-#endif
-        _fused8bitrowwise_to_float_mixed_dim_cuda_kernel<scalar_t>
-            <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                MAKE_PTA_WITH_NAME(func_name, input, uint8_t, 2, 32),
-                MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-                MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        FBGEMM_LAUNCH_KERNEL(
+            (_fused8bitrowwise_to_float_mixed_dim_cuda_kernel<scalar_t>),
+            gridDim,
+            blockDim,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(input, uint8_t, 2, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
+            PTA_B(output, scalar_t, 2, 32));
       });
   return output;
 }
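
Two things worth noting in the last hunk above. First, the #ifdef FBGEMM_GPU_MEMCHECK block that manually defined func_name goes away, and MAKE_PTA_WITH_NAME(func_name, ...) becomes PTA_B(...): the builder form no longer takes a kernel-name string at the call site, which suggests the launcher supplies that context for memcheck itself (inferred from this diff, not verified against kernel_launcher.cuh). Second, this is the one call site in the file whose arguments are packed tensor accessors rather than raw data_ptr<T>() pointers, and the migration handles both shapes the same way.
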
43 changes: 23 additions & 20 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -149,17 +149,17 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
 
   FBGEMM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "_float_to_fusednbitrowwise_cuda_kernel", [&] {
-        _float_to_fusednbitrowwise_cuda_kernel<scalar_t>
-            <<<num_blocks,
-               threads_per_block,
-               0,
-               at::cuda::getCurrentCUDAStream()>>>(
-                bit_rate,
-                input.data_ptr<scalar_t>(),
-                nrows,
-                ncols,
-                output.data_ptr<std::uint8_t>());
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        FBGEMM_LAUNCH_KERNEL(
+            (_float_to_fusednbitrowwise_cuda_kernel<scalar_t>),
+            num_blocks,
+            threads_per_block,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            bit_rate,
+            input.data_ptr<scalar_t>(),
+            nrows,
+            ncols,
+            output.data_ptr<std::uint8_t>());
       });
 
   return output;
@@ -267,14 +267,18 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);
 
-#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                        \
-  _fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last> \
-      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
-          bit_rate,                                                 \
-          input.data_ptr<std::uint8_t>(),                           \
-          nrows,                                                    \
-          ncols,                                                    \
-          output.data_ptr<scalar_t>())
+#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                               \
+  FBGEMM_LAUNCH_KERNEL(                                                    \
+      (_fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>), \
+      gridDim,                                                             \
+      blockDim,                                                            \
+      0,                                                                   \
+      at::cuda::getCurrentCUDAStream(),                                    \
+      bit_rate,                                                            \
+      input.data_ptr<std::uint8_t>(),                                      \
+      nrows,                                                               \
+      ncols,                                                               \
+      output.data_ptr<scalar_t>())
 
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
@@ -283,7 +287,6 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
         } else {
           DEQUANT_LAUNCH_NBIT(false);
         }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 #undef DEQUANT_LAUNCH_NBIT
   return output;
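
For reviewers unfamiliar with the macro, here is a rough sketch of the contract a launcher like FBGEMM_LAUNCH_KERNEL has to satisfy, inferred purely from the call sites in this PR. This is not FBGEMM's actual implementation (the real macro also handles details like the accessor-name context noted above, and is a macro rather than a function template); it is a minimal stand-in showing why the per-call-site C10_CUDA_KERNEL_LAUNCH_CHECK() lines can be deleted:

// Hypothetical sketch only -- NOT the FBGEMM implementation.
#include <cuda_runtime.h>
#include <utility>
#include <c10/cuda/CUDAException.h>

// Forward the launch configuration and arguments to the kernel, then run
// the post-launch check that every call site previously invoked by hand.
template <typename Kernel, typename... Args>
void launch_kernel_sketch(
    Kernel kernel, // pointer to a __global__ function
    const dim3 grid,
    const dim3 block,
    const size_t smem_bytes,
    cudaStream_t stream,
    Args&&... args) {
  kernel<<<grid, block, smem_bytes, stream>>>(std::forward<Args>(args)...);
  C10_CUDA_KERNEL_LAUNCH_CHECK(); // replaces the manual per-site check
}
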