
Commit 49422df

Haicheng Wang authored and facebook-github-bot committed
Back out "Vectorize load/store for FP8 Quantization" (pytorch#4417)
Summary: Pull Request resolved: pytorch#4417
Original commit changeset: ceb3b61319cf
Original Phabricator Diff: D75563906
Related debugging doc: https://fburl.com/gdoc/vo8yhfog
Reviewed By: spcyppt, basilwong, q10
Differential Revision: D77541785
fbshipit-source-id: e8ab7a54158375c386bb4715b84d27646e9fdbc2
1 parent 462a8b3 commit 49422df

Showing 4 changed files with 11 additions and 141 deletions.

fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 0 additions & 1 deletion
@@ -183,6 +183,5 @@ gpu_cpp_library(
     fbgemm_gpu_tbe_cache
     fbgemm_gpu_tbe_optimizers
     fbgemm_gpu_tbe_utils
-    fbgemm_gpu_config
   DESTINATION
     fbgemm_gpu)

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 0 additions & 3 deletions
@@ -60,9 +60,6 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
-    # Disable FP8 quantization vectorization
-    DISABLE_FP8_QUANT_VECTORIZATION = auto()
-
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
 

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ namespace fbgemm_gpu::config {
   X(TBE_ANNOTATE_KINETO_TRACE) \
   X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
   X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
-  X(BOUNDS_CHECK_INDICES_V2) \
-  X(DISABLE_FP8_QUANT_VECTORIZATION)
+  X(BOUNDS_CHECK_INDICES_V2)
   // X(EXAMPLE_FEATURE_FLAG)
 
 /// @ingroup fbgemm-gpu-config
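
For context, each X(...) entry above expands into a FeatureGateName value that callers query at runtime; the removed quantize_fp8_rowwise.cu code below did exactly that for DISABLE_FP8_QUANT_VECTORIZATION. A minimal sketch of such a check, assuming only the fbgemm_gpu::config API visible in that removed code (the use_bounds_check_v2 wrapper is hypothetical):

#include "fbgemm_gpu/config/feature_gates.h"

// Hypothetical caller: gates a code path on a flag that remains in the
// X(...) list after this backout.
bool use_bounds_check_v2() {
  return fbgemm_gpu::config::is_feature_enabled(
      fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
}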

fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu

Lines changed: 10 additions & 135 deletions
@@ -7,7 +7,6 @@
  */
 
 #include "common.cuh"
-#include "fbgemm_gpu/config/feature_gates.h"
 
 using Tensor = at::Tensor;
 
@@ -118,33 +117,6 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
   output_row_qparams[1] = 0.0;
 }
 
-template <typename scalar_t>
-struct VectorSizeTraits {
-  // Default to 4 elements for most types (16 bytes for float)
-  static constexpr int value = 4;
-};
-
-// Specialization for half (float16)
-template <>
-struct VectorSizeTraits<c10::Half> {
-  // 8 elements for half precision (16 bytes total)
-  static constexpr int value = 8;
-};
-
-// Specialization for __nv_bfloat16
-template <>
-struct VectorSizeTraits<c10::BFloat16> {
-  // 8 elements for bfloat16 precision (16 bytes total)
-  static constexpr int value = 8;
-};
-
-// aligned vector generates vectorized load/store on CUDA (copy-pasted from
-// MemoryAccess.cuh)
-template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
-struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
-  scalar_t val[vec_size];
-};
-
 template <typename input_t>
 __global__ inline void _compute_FP8_quantize_cuda_kernel(
     const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
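
The structs deleted above exist to make 16-byte transactions possible: alignas(sizeof(scalar_t) * vec_size) marks the wrapper as naturally aligned, so one load or store can move 4 floats or 8 half/bfloat16 values. A minimal standalone sketch of the same idea (plain CUDA, not FBGEMM code; copy_vec4 is illustrative and assumes 16-byte-aligned buffers):

#include <cuda_runtime.h>

template <typename T, int N>
struct alignas(sizeof(T) * N) aligned_vector {
  T val[N];
};

__global__ void copy_vec4(const float* in, float* out, int n_vec) {
  using vec4 = aligned_vector<float, 4>;  // 16 bytes per element
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec) {
    // One 16-byte load and one 16-byte store per thread, provided that
    // in and out are 16-byte aligned.
    reinterpret_cast<vec4*>(out)[i] = reinterpret_cast<const vec4*>(in)[i];
  }
}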
@@ -185,77 +157,6 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
   }
 }
 
-template <typename input_t>
-__global__ inline void _compute_FP8_quantize_cuda_vectorized_kernel(
-    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
-    const int64_t nrows,
-    const int64_t ncols,
-    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
-    const bool forward) {
-  CUDA_KERNEL_ASSERT(nrows * ncols >= 0);
-
-  // Calculate global row index with 2D thread blocks
-  const int64_t gy = blockIdx.y * blockDim.y + threadIdx.y;
-  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
-  // Early return if row is out of bounds
-  if (gy >= nrows || (thread_idx * vec_size) >= ncols) {
-    return;
-  }
-
-  int ebit = forward ? 4 : 5;
-  int bias = forward ? 15 : 31;
-  float max_pos = forward ? 0.9375 : 0.875;
-
-  // Calculate output width
-  const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
-  const auto output_columns = ncols_aligned + 2 * sizeof(float);
-
-  // Calculate base offsets for the current row
-  const int64_t input_row_offset = gy * ncols;
-  const int64_t output_row_offset = gy * output_columns;
-
-  // Calculate the position where the scale values are stored
-  const int64_t scale_offset = output_row_offset + ncols_aligned;
-  const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];
-
-  const int64_t vector_blocks = ncols / vec_size;
-
-  using vec_t = aligned_vector<input_t, vec_size>;
-  using vec_i = aligned_vector<uint8_t, vec_size>;
-
-  const int64_t col_idx = thread_idx * vec_size;
-
-  // Don't access beyond the valid input data
-  if (col_idx + (vec_size - 1) < ncols) {
-    // Load vec_size elements - handle both aligned and unaligned cases
-    // correctly
-    const vec_t* input_row =
-        reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
-
-    vec_i* output_row =
-        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
-
-#pragma unroll
-    for (int i = 0; i < vec_size; ++i) {
-      output_row->val[i] = float_to_hfp8(
-          to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
-    }
-  }
-
-  // 2. Process any remaining elements (less than vec_size) with scalar
-  // operations
-  const int64_t remaining_start = vector_blocks * vec_size;
-  for (int64_t col = remaining_start + threadIdx.x; col < ncols;
-       col += blockDim.x) {
-    output[output_row_offset + col] = float_to_hfp8(
-        to_float(input[input_row_offset + col]) * scale_value,
-        ebit,
-        bias,
-        max_pos);
-  }
-}
-
 template <typename output_t>
 __global__ inline void _FP8rowwise_to_float_cuda_kernel(
     pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
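
The kernel deleted above follows a common structure: each thread quantizes one vec_size-wide chunk with a single vectorized load/store, and a scalar loop mops up the ncols % vec_size tail. A compressed sketch of that structure under simplified assumptions (single row, float4 instead of aligned_vector, scale passed in rather than read back from the output row; scale_row_vec4 is illustrative, not the FBGEMM kernel, and assumes 16-byte-aligned buffers):

#include <cstdint>
#include <cuda_runtime.h>

__global__ void scale_row_vec4(
    const float* in, float* out, int64_t ncols, float scale) {
  constexpr int vec_size = 4;
  const int64_t col =
      static_cast<int64_t>(blockIdx.x * blockDim.x + threadIdx.x) * vec_size;

  // Vectorized body: full 4-element chunks, one 16-byte load/store each.
  if (col + vec_size - 1 < ncols) {
    float4 v = reinterpret_cast<const float4*>(in + col)[0];
    v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
    reinterpret_cast<float4*>(out + col)[0] = v;
  }

  // Scalar tail: the last ncols % vec_size elements, handled by one block
  // to avoid redundant writes.
  if (blockIdx.x == 0) {
    const int64_t tail_start = (ncols / vec_size) * vec_size;
    for (int64_t c = tail_start + threadIdx.x; c < ncols; c += blockDim.x) {
      out[c] = in[c] * scale;
    }
  }
}

Launched with ceil((ncols / vec_size) / blockDim.x) blocks in x, this covers every column exactly once.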
@@ -349,6 +250,13 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
   } else {
+    // range_tensor is used to store the range for each embedding row.
+    // We save max_pos/max_val(rowwise) as row scale to quantize
+    // unlike INT8, FP8 does not have zero shift
+    // This will guarantee the numerical match but bring some perf
+    // regression.
+    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
+
     {
       // we need a blockDim.x that is a power of 2 no larger than the warp size
       // of 32
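
The comment added above describes the scheme the remaining scalar path keeps: one float scale per row, computed as max_pos / max_val over that row, and no zero point (unlike INT8 rowwise quantization). A host-side sketch of that scale computation, assuming the forward max_pos value of 0.9375 used by the kernels in this file (fp8_row_scale itself is illustrative, not FBGEMM code):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One scale per row; each element is later encoded as hfp8(value * scale).
float fp8_row_scale(const float* row, int64_t ncols, float max_pos = 0.9375f) {
  float max_val = 0.0f;
  for (int64_t c = 0; c < ncols; ++c) {
    max_val = std::max(max_val, std::fabs(row[c]));
  }
  // No zero shift: scaling alone maps the row into the representable range.
  return max_val > 0.0f ? max_pos / max_val : 1.0f;
}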
@@ -386,40 +294,7 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
           });
     }
 
-    const uintptr_t addr = reinterpret_cast<uintptr_t>(&input);
-    const static bool use_vectorization =
-        ((addr % 16) == 0) &&
-        !config::is_feature_enabled(
-            config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
-
-    if (use_vectorization) {
-      const constexpr int vec_size = VectorSizeTraits<input_t>::value;
-      const int blockDim_x = std::min(
-          ncols > vec_size ? ncols / vec_size : 1,
-          static_cast<int64_t>(threads_per_block));
-      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
-      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
-      dim3 gridDim(gridDim_x, gridDim_y);
-
-      FBGEMM_DISPATCH_FLOATING_TYPES(
-          input.scalar_type(),
-          "_compute_FP8_quantize_cuda_vectorized_kernel",
-          [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-            const auto func_name =
-                "_compute_FP8_quantize_cuda_vectorized_kernel";
-#endif
-            _compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>
-                <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                    MAKE_PTA_WITH_NAME(func_name, input_1D, scalar_t, 1, 64),
-                    nrows,
-                    ncols,
-                    MAKE_PTA_WITH_NAME(func_name, output_1D, uint8_t, 1, 64),
-                    forward);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-          });
-    } else {
+    {
       const int blockDim_x =
           std::min(ncols, static_cast<int64_t>(threads_per_block));
       dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
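
The surviving path above picks blockDim.x from the column count and derives the grid by ceil-division over rows and columns. A plain-CUDA sketch of that launch-shape calculation, with ordinary ceil-division standing in for FBGEMM's cuda_calc_xblock_count / cuda_calc_block_count helpers (make_block and make_grid are hypothetical names):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

inline dim3 make_block(int64_t ncols, int threads_per_block) {
  const int block_x =
      static_cast<int>(std::min<int64_t>(ncols, threads_per_block));
  // Remaining threads go to the y dimension (rows).
  return dim3(block_x, threads_per_block / block_x);
}

inline dim3 make_grid(int64_t nrows, int64_t ncols, dim3 block) {
  // Ceil-division so every row and column is covered.
  return dim3((ncols + block.x - 1) / block.x,
              (nrows + block.y - 1) / block.y);
}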
@@ -489,8 +364,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
   // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
-  // bytes and the data is naturally aligned (i.e., its address is a multiple
-  // of that size).
+  // bytes and the data is naturally aligned (i.e., its address is a multiple of
+  // that size).
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
   const auto output_sdtype = static_cast<SparseType>(output_dtype);
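
The rewrapped comment above restates the CUDA rule the removed vectorized path relied on: a 16-byte access compiles to a single global-memory instruction only when the address is a multiple of 16. A tiny sketch of checking that precondition on an actual data pointer (is_16_byte_aligned is a hypothetical helper, not part of this file):

#include <cstdint>

inline bool is_16_byte_aligned(const void* data_ptr) {
  // Natural-alignment check for 16-byte vectorized loads/stores.
  return reinterpret_cast<std::uintptr_t>(data_ptr) % 16 == 0;
}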

0 commit comments