
Commit db28973

flaviotruzzi authored and facebook-github-bot committed
Vectorize load/store for FP8 Quantization (#4262)
Summary:
Pull Request resolved: #4262
X-link: facebookresearch/FBGEMM#1340

Vectorize the memory accesses for the loads and stores in FloatToFP8 quantization, so that each load/store moves up to the full 16 bytes a single global-memory transaction can cover, choosing the vector width according to the tensor's dtype.

Reviewed By: spcyppt, jwfromm, q10

Differential Revision: D75563906

fbshipit-source-id: ceb3b61319cf2a989c3d51b89570e7a265c6013a
1 parent 9bd0892 commit db28973
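The core trick in this patch is standard CUDA vectorization: wrap N elements in a struct whose alignment equals its total size, so the compiler can emit a single 128-bit load/store instead of N scalar transactions. Below is a minimal, self-contained sketch of that idea; the kernel and names here are illustrative and not part of the patch.

#include <cuda_runtime.h>

// Over-aligned wrapper: for float and N = 4 this is 16 bytes with 16-byte
// alignment, so nvcc can lower the copy to one ld.global.v4 / st.global.v4.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

// Copy n floats, 4 at a time (16 bytes per thread). Assumes src/dst are
// 16-byte aligned and n is a multiple of 4; real code needs a scalar tail.
__global__ void copy_vec4(const float* src, float* dst, int n) {
  using Vec4 = AlignedVec<float, 4>;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 4 + 3 < n) {
    const Vec4 v = reinterpret_cast<const Vec4*>(src)[i];
    reinterpret_cast<Vec4*>(dst)[i] = v;
  }
}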

File tree: 4 files changed, +141 −11 lines

fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 1 addition & 0 deletions
@@ -183,5 +183,6 @@ gpu_cpp_library(
   fbgemm_gpu_tbe_cache
   fbgemm_gpu_tbe_optimizers
   fbgemm_gpu_tbe_utils
+  fbgemm_gpu_config
   DESTINATION
   fbgemm_gpu)

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # Disable FP8 quantization vectorization
+    DISABLE_FP8_QUANT_VECTORIZATION = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ namespace fbgemm_gpu::config {
   X(TBE_ANNOTATE_KINETO_TRACE) \
   X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
   X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
-  X(BOUNDS_CHECK_INDICES_V2)
+  X(BOUNDS_CHECK_INDICES_V2) \
+  X(DISABLE_FP8_QUANT_VECTORIZATION)
 // X(EXAMPLE_FEATURE_FLAG)
 
 /// @ingroup fbgemm-gpu-config
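The flag list above is an X-macro: each flag name is passed through a caller-supplied macro X. The exact expansion inside feature_gates.h is not shown in this diff; the sketch below is the standard X-macro pattern with a hypothetical two-flag list, showing how one list can stamp out both an enum and its string names. (The .cu change below queries the gate via config::is_feature_enabled(config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION).)

// Hypothetical two-flag list in the same X-macro style as above.
#define DEMO_FEATURE_FLAGS(X) \
  X(BOUNDS_CHECK_INDICES_V2)  \
  X(DISABLE_FP8_QUANT_VECTORIZATION)

// Expansion 1: one enumerator per flag.
enum class FeatureGateName {
#define X(name) name,
  DEMO_FEATURE_FLAGS(X)
#undef X
};

// Expansion 2: the matching string name, e.g. for an env-var lookup.
inline const char* to_string(FeatureGateName gate) {
  switch (gate) {
#define X(name)               \
  case FeatureGateName::name: \
    return #name;
    DEMO_FEATURE_FLAGS(X)
#undef X
  }
  return "UNKNOWN";
}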

fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu

Lines changed: 135 additions & 10 deletions
@@ -7,6 +7,7 @@
  */
 
 #include "common.cuh"
+#include "fbgemm_gpu/config/feature_gates.h"
 
 using Tensor = at::Tensor;
 
@@ -117,6 +118,33 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
   output_row_qparams[1] = 0.0;
 }
 
+template <typename scalar_t>
+struct VectorSizeTraits {
+  // Default to 4 elements for most types (16 bytes for float)
+  static constexpr int value = 4;
+};
+
+// Specialization for half (float16)
+template <>
+struct VectorSizeTraits<c10::Half> {
+  // 8 elements for half precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// Specialization for __nv_bfloat16
+template <>
+struct VectorSizeTraits<c10::BFloat16> {
+  // 8 elements for bfloat16 precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// aligned vector generates vectorized load/store on CUDA (copy-pasted from
+// MemoryAccess.cuh)
+template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+  scalar_t val[vec_size];
+};
+
 template <typename input_t>
 __global__ inline void _compute_FP8_quantize_cuda_kernel(
     const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
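As a quick sanity check of the traits added above, every instantiation comes out to 16 bytes with 16-byte alignment, the widest single global-memory transaction CUDA guarantees. This is an illustrative standalone snippet using a plain 2-byte type as a stand-in for c10::Half / c10::BFloat16:

#include <cstdio>

template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

int main() {
  // float: 4 elements x 4 bytes = 16 bytes, 16-byte aligned.
  std::printf("%zu %zu\n", sizeof(aligned_vector<float, 4>),
              alignof(aligned_vector<float, 4>)); // prints: 16 16
  // 2-byte types (stand-in for c10::Half / c10::BFloat16): 8 x 2 = 16.
  std::printf("%zu %zu\n", sizeof(aligned_vector<unsigned short, 8>),
              alignof(aligned_vector<unsigned short, 8>)); // prints: 16 16
  return 0;
}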
@@ -157,6 +185,77 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
   }
 }
 
+template <typename input_t>
+__global__ inline void _compute_FP8_quantize_cuda_vectorized_kernel(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  CUDA_KERNEL_ASSERT(nrows * ncols >= 0);
+
+  // Calculate global row index with 2D thread blocks
+  const int64_t gy = blockIdx.y * blockDim.y + threadIdx.y;
+  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  // Early return if row is out of bounds
+  if (gy >= nrows || (thread_idx * vec_size) >= ncols) {
+    return;
+  }
+
+  int ebit = forward ? 4 : 5;
+  int bias = forward ? 15 : 31;
+  float max_pos = forward ? 0.9375 : 0.875;
+
+  // Calculate output width
+  const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const auto output_columns = ncols_aligned + 2 * sizeof(float);
+
+  // Calculate base offsets for the current row
+  const int64_t input_row_offset = gy * ncols;
+  const int64_t output_row_offset = gy * output_columns;
+
+  // Calculate the position where the scale values are stored
+  const int64_t scale_offset = output_row_offset + ncols_aligned;
+  const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];
+
+  const int64_t vector_blocks = ncols / vec_size;
+
+  using vec_t = aligned_vector<input_t, vec_size>;
+  using vec_i = aligned_vector<uint8_t, vec_size>;
+
+  const int64_t col_idx = thread_idx * vec_size;
+
+  // Don't access beyond the valid input data
+  if (col_idx + (vec_size - 1) < ncols) {
+    // Load vec_size elements - handle both aligned and unaligned cases
+    // correctly
+    const vec_t* input_row =
+        reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      output_row->val[i] = float_to_hfp8(
+          to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
+    }
+  }
+
+  // 2. Process any remaining elements (less than vec_size) with scalar
+  // operations
+  const int64_t remaining_start = vector_blocks * vec_size;
+  for (int64_t col = remaining_start + threadIdx.x; col < ncols;
+       col += blockDim.x) {
+    output[output_row_offset + col] = float_to_hfp8(
+        to_float(input[input_row_offset + col]) * scale_value,
+        ebit,
+        bias,
+        max_pos);
+  }
+}
+
 template <typename output_t>
 __global__ inline void _FP8rowwise_to_float_cuda_kernel(
     pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
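The scale_value the kernel reads back sits at the tail of each output row: ncols quantized bytes padded up to a multiple of 4, then two floats of qparams (the scale, then 0.0, since FP8 rowwise has no zero shift; the _get_FP8_qparam_cuda_kernel context above shows qparams[1] being zeroed). A worked example of that layout arithmetic, as a standalone sketch:

#include <cstdint>
#include <cstdio>

int main() {
  // Same arithmetic as the kernel, for a row of 5 columns.
  const int64_t ncols = 5;
  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;            // 8
  const int64_t output_columns = ncols_aligned + 2 * sizeof(float); // 16
  const int64_t scale_offset = ncols_aligned;                       // 8

  // Row layout: bytes [0, 5) quantized data, [5, 8) padding,
  // [8, 12) float scale, [12, 16) float 0.0.
  std::printf("%lld %lld %lld\n", (long long)ncols_aligned,
              (long long)output_columns, (long long)scale_offset); // 8 16 8
  return 0;
}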
@@ -250,13 +349,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   } else {
-    // range_tensor is used to store the range for each embedding row.
-    // We save max_pos/max_val(rowwise) as row scale to quantize
-    // unlike INT8, FP8 does not have zero shift
-    // This will guarantee the numerical match but bring some perf
-    // regression.
-    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
-
     {
       // we need a blockDim.x that is a power of 2 no larger than the warp size
       // of 32
@@ -294,7 +386,40 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
       });
     }
 
-    {
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(&input);
+    const static bool use_vectorization =
+        ((addr % 16) == 0) &&
+        !config::is_feature_enabled(
+            config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
+
+    if (use_vectorization) {
+      const constexpr int vec_size = VectorSizeTraits<input_t>::value;
+      const int blockDim_x = std::min(
+          ncols > vec_size ? ncols / vec_size : 1,
+          static_cast<int64_t>(threads_per_block));
+      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);
+
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(),
+          "_compute_FP8_quantize_cuda_vectorized_kernel",
+          [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+            const auto func_name =
+                "_compute_FP8_quantize_cuda_vectorized_kernel";
+#endif
+            _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
+                <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
+                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                    nrows,
+                    ncols,
+                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                    forward);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          });
+    } else {
       const int blockDim_x =
           std::min(ncols, static_cast<int64_t>(threads_per_block));
       dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
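For concreteness, here is the launch geometry the vectorized branch above produces, recomputed in a standalone snippet. The FBGEMM helpers cuda_calc_xblock_count / cuda_calc_block_count are replaced with a plain ceiling division (an assumption about their behavior here), for float input (vec_size = 4) and an assumed threads_per_block of 256:

#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t nrows = 1000, ncols = 512;
  const int64_t threads_per_block = 256;
  const int64_t vec_size = 4; // VectorSizeTraits<float>::value

  // Each x-thread covers vec_size columns; remaining threads go to y (rows).
  const int64_t blockDim_x =
      std::min(ncols > vec_size ? ncols / vec_size : 1, threads_per_block);
  const int64_t blockDim_y = threads_per_block / blockDim_x;

  const int64_t gridDim_x = ceil_div(ncols, blockDim_x);
  const int64_t gridDim_y = ceil_div(nrows, blockDim_y);

  std::printf("block=(%lld,%lld) grid=(%lld,%lld)\n", (long long)blockDim_x,
              (long long)blockDim_y, (long long)gridDim_x,
              (long long)gridDim_y);
  // block=(128,2) grid=(4,500). Note gridDim_x is derived from ncols rather
  // than ncols / vec_size, so x-blocks past the first return early through
  // the kernel's bounds check.
  return 0;
}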
@@ -364,8 +489,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
   // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
-  // bytes and the data is naturally aligned (i.e., its address is a multiple of
-  // that size).
+  // bytes and the data is naturally aligned (i.e., its address is a multiple
+  // of that size).
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
   const auto output_sdtype = static_cast<SparseType>(output_dtype);
