@@ -7,6 +7,7 @@
  */
 
 #include "common.cuh"
+#include "fbgemm_gpu/config/feature_gates.h"
 
 using Tensor = at::Tensor;
 
@@ -117,6 +118,33 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
   output_row_qparams[1] = 0.0;
 }
 
+template <typename scalar_t>
+struct VectorSizeTraits {
+  // Default to 4 elements for most types (16 bytes for float)
+  static constexpr int value = 4;
+};
+
+// Specialization for half (float16)
+template <>
+struct VectorSizeTraits<c10::Half> {
+  // 8 elements for half precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// Specialization for __nv_bfloat16
+template <>
+struct VectorSizeTraits<c10::BFloat16> {
+  // 8 elements for bfloat16 precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// aligned vector generates vectorized load/store on CUDA (copy-pasted from
+// MemoryAccess.cuh)
+template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+  scalar_t val[vec_size];
+};
+
 template <typename input_t>
 __global__ inline void _compute_FP8_quantize_cuda_kernel(
     const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
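
Aside (not part of the diff): the `alignas` wrapper above is what lets the compiler turn a per-element copy into a single 16-byte load/store per thread. A minimal standalone sketch of the same pattern, with hypothetical names (`vec_wrapper`, `copy4`):

```cuda
#include <cstdint>

// Same idea as aligned_vector above: force N elements into one 16-byte
// aligned chunk so the compiler can emit one vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) vec_wrapper {
  T val[N];
};

// Hypothetical copy kernel: when `in` and `out` are 16-byte aligned, each
// thread moves four floats with one 16-byte load and one 16-byte store.
__global__ void copy4(
    const float* __restrict__ in,
    float* __restrict__ out,
    int64_t n) {
  using vec4 = vec_wrapper<float, 4>;
  const int64_t i =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * 4;
  if (i + 3 < n) {
    *reinterpret_cast<vec4*>(out + i) = *reinterpret_cast<const vec4*>(in + i);
  } else {
    // Scalar tail for the final partial vector.
    for (int64_t j = i; j < n; ++j) {
      out[j] = in[j];
    }
  }
}
```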
@@ -157,6 +185,77 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
   }
 }
 
+template <typename input_t>
+__global__ inline void _compute_FP8_quantize_cuda_vectorized_kernel(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  CUDA_KERNEL_ASSERT(nrows * ncols >= 0);
+
+  // Calculate global row index with 2D thread blocks
+  const int64_t gy = blockIdx.y * blockDim.y + threadIdx.y;
+  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  // Early return if row is out of bounds
+  if (gy >= nrows || (thread_idx * vec_size) >= ncols) {
+    return;
+  }
+
+  int ebit = forward ? 4 : 5;
+  int bias = forward ? 15 : 31;
+  float max_pos = forward ? 0.9375 : 0.875;
+
+  // Calculate output width
+  const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const auto output_columns = ncols_aligned + 2 * sizeof(float);
+
+  // Calculate base offsets for the current row
+  const int64_t input_row_offset = gy * ncols;
+  const int64_t output_row_offset = gy * output_columns;
+
+  // Calculate the position where the scale values are stored
+  const int64_t scale_offset = output_row_offset + ncols_aligned;
+  const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];
+
+  const int64_t vector_blocks = ncols / vec_size;
+
+  using vec_t = aligned_vector<input_t, vec_size>;
+  using vec_i = aligned_vector<uint8_t, vec_size>;
+
+  const int64_t col_idx = thread_idx * vec_size;
+
+  // 1. Vectorized path: don't access beyond the valid input data
+  if (col_idx + (vec_size - 1) < ncols) {
+    // Load vec_size contiguous elements with a single vectorized access;
+    // the caller only selects this kernel when the input is 16-byte aligned.
+    const vec_t* input_row =
+        reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      output_row->val[i] = float_to_hfp8(
+          to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
+    }
+  }
+
+  // 2. Process any remaining elements (less than vec_size) with scalar
+  // operations
+  const int64_t remaining_start = vector_blocks * vec_size;
+  for (int64_t col = remaining_start + threadIdx.x; col < ncols;
+       col += blockDim.x) {
+    output[output_row_offset + col] = float_to_hfp8(
+        to_float(input[input_row_offset + col]) * scale_value,
+        ebit,
+        bias,
+        max_pos);
+  }
+}
+
 template <typename output_t>
 __global__ inline void _FP8rowwise_to_float_cuda_kernel(
     pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
@@ -250,13 +349,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
|
250 | 349 | C10_CUDA_KERNEL_LAUNCH_CHECK();
|
251 | 350 | });
|
252 | 351 | } else {
|
253 |
| - // range_tensor is used to store the range for each embedding row. |
254 |
| - // We save max_pos/max_val(rowwise) as row scale to quantize |
255 |
| - // unlike INT8, FP8 does not have zero shift |
256 |
| - // This will guarantee the numerical match but bring some perf |
257 |
| - // regression. |
258 |
| - auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat)); |
259 |
| -
|
260 | 352 | {
|
261 | 353 | // we need a blockDim.x that is a power of 2 no larger than the warp size
|
262 | 354 | // of 32
@@ -294,7 +386,40 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
       });
     }
 
-    {
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(input.data_ptr());
+    const bool use_vectorization =
+        ((addr % 16) == 0) &&
+        !config::is_feature_enabled(
+            config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
+
+    if (use_vectorization) {
+      const constexpr int vec_size = VectorSizeTraits<input_t>::value;
+      const int blockDim_x = std::min(
+          ncols > vec_size ? ncols / vec_size : 1,
+          static_cast<int64_t>(threads_per_block));
+      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);
+
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(),
+          "_compute_FP8_quantize_cuda_vectorized_kernel",
+          [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+            const auto func_name =
+                "_compute_FP8_quantize_cuda_vectorized_kernel";
+#endif
+            _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
+                <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
+                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
+                    nrows,
+                    ncols,
+                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
+                    forward);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+          });
+    } else {
       const int blockDim_x =
           std::min(ncols, static_cast<int64_t>(threads_per_block));
       dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
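
Aside (not part of the diff): a rough worked example of the launch-shape math in the vectorized branch; `threads_per_block = 1024` and the tensor sizes are assumptions, not values taken from this file:

```cuda
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed sizes, for illustration only.
  const int64_t threads_per_block = 1024;
  const int64_t nrows = 4096, ncols = 512;
  const int vec_size = 4; // float input

  const int blockDim_x = std::min<int64_t>(
      ncols > vec_size ? ncols / vec_size : 1, threads_per_block); // 128
  const int blockDim_y = threads_per_block / blockDim_x; // 8
  // cuda_calc_xblock_count / cuda_calc_block_count behave as ceil-divisions.
  const int64_t gridDim_x = (ncols + blockDim_x - 1) / blockDim_x; // 4
  const int64_t gridDim_y = (nrows + blockDim_y - 1) / blockDim_y; // 512

  // gridDim_x is derived from ncols rather than ncols / vec_size, so only the
  // first x-block has work here; the rest exit at the kernel's bounds check.
  std::printf(
      "block=(%d,%d) grid=(%lld,%lld)\n",
      blockDim_x,
      blockDim_y,
      static_cast<long long>(gridDim_x),
      static_cast<long long>(gridDim_y));
  return 0;
}
```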
@@ -364,8 +489,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
   // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
-  // bytes and the data is naturally aligned (i.e., its address is a multiple of
-  // that size).
+  // bytes and the data is naturally aligned (i.e., its address is a multiple
+  // of that size).
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
   const auto output_sdtype = static_cast<SparseType>(output_dtype);
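
Aside (not part of the diff): a tiny illustration of the "naturally aligned" condition the reflowed comment refers to; the helper name is hypothetical:

```cuda
#include <cstddef>
#include <cstdint>
#include <cstdio>

// "Naturally aligned": an access of 1, 2, 4, 8, or 16 bytes compiles to a
// single global memory instruction only when the address is a multiple of
// the access size. Hypothetical helper, not from the codebase.
inline bool naturally_aligned(const void* p, std::size_t access_size) {
  return reinterpret_cast<std::uintptr_t>(p) % access_size == 0;
}

int main() {
  alignas(16) float buf[8] = {};
  std::printf(
      "%d %d\n",
      naturally_aligned(buf, 16), // 1: base address is 16-byte aligned
      naturally_aligned(buf + 1, 16)); // 0: offset by 4 bytes
  return 0;
}
```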