diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 56fbaf1c01..8e78447d5c 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Tuple +from typing import Tuple import fire import torch import triton -from torch._inductor.utils import do_bench_using_profiling +from triton.testing import do_bench from torchao.prototype.mx_formats.kernels import ( triton_to_mxfp8_dim1, @@ -64,29 +64,35 @@ def to_mx_dim1_reference(x_hp, block_size): return data_d1.t(), scale_d1 -def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: - """Thin wrapper around do_bench_using_profiling""" - no_args = lambda: func(*args, **kwargs) - time = do_bench_using_profiling(no_args) - return time * 1e3 +def benchmark_cuda_function_in_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 def run( M: int = 16384, K: int = 16384, BLOCK_SIZE: int = 32, - mode: str = "dim0", + mode: str = "dim0_floor", ): print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}") print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"torch version: {torch.__version__}") print(f"triton version: {triton.__version__}") print(f"mode: {mode}") - assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton") + assert mode in ( + "dim0_floor", + "dim1_floor", + "dim0_dim1_floor", + "dim0_mx_floor", + "dim1_mx_floor", + "dim1_mx_triton_floor", + "dim1_mx_cuda_floor", + "dim1_mx_cuda_rceil", + ) x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 - if mode == "dim0": + if mode == "dim0_floor": scale_dim0_reference_c = torch.compile(scale_dim0_reference) y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE) @@ -103,7 +109,7 @@ def run( bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16 bps = bytes_rw / (time_us / 1e6) - elif mode == "dim1": + elif mode == "dim1_floor": scale_dim1_reference_c = torch.compile(scale_dim1_reference) y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE) @@ -120,7 +126,7 @@ def run( bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16 bps = bytes_rw / (time_us / 1e6) - elif mode == "dim0_dim1": + elif mode == "dim0_dim1_floor": scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference) y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE) @@ -141,7 +147,7 @@ def run( ) bps = bytes_rw / (time_us / 1e6) - elif mode == "dim0_mx": + elif mode == "dim0_mx_floor": to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE) @@ -159,7 +165,7 @@ def run( bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) - elif mode == "dim1_mx": + elif mode == "dim1_mx_floor": to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference) y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE) @@ -177,7 +183,7 @@ def run( bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) - elif mode == "dim1_mx_triton": + elif mode == "dim1_mx_triton_floor": y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) for _ in range(2): @@ -194,6 +200,58 @@ def run( bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) + elif mode == "dim1_mx_cuda_floor": + from 
torchao.prototype import mxfp8_cuda + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ) + + for _ in range(2): + __ = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ) + + time_us = benchmark_cuda_function_in_microseconds( + lambda x: mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ), + x, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mx_cuda_rceil": + from torchao.prototype import mxfp8_cuda + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ) + + for _ in range(2): + __ = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ) + + time_us = benchmark_cuda_function_in_microseconds( + lambda x: mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ), + x, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + else: raise AssertionError(f"unknown mode {mode}") diff --git a/setup.py b/setup.py index 88669e7b3b..2572151145 100644 --- a/setup.py +++ b/setup.py @@ -490,6 +490,14 @@ def get_extensions(): if use_cuda: sources += cuda_sources + # Add MXFP8 cuda extension dir + mxfp8_extension_dir = os.path.join(extensions_dir, "cuda", "mx_kernels") + mxfp8_sources_to_exclude = list( + glob.glob(os.path.join(mxfp8_extension_dir, "**/*"), recursive=True) + ) + sources = [s for s in sources if s not in mxfp8_sources_to_exclude] + print("sources after mxfp8 exclusion", sources) + # TOOD: Remove this and use what CUDA has once we fix all the builds. 
if use_rocm: # Add ROCm GPU architecture check @@ -610,6 +618,36 @@ def get_extensions(): ) ) + # Add the mxfp8 casting CUDA extension + if use_cuda: + mxfp8_sources = [ + os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"), + os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"), + ] + + # Only add the extension if the source files exist AND we are building for sm100 + mxfp8_src_files_exist = all(os.path.exists(f) for f in mxfp8_sources) + if mxfp8_src_files_exist and build_for_sm100a: + print("Building mxfp8_cuda extension") + ext_modules.append( + CUDAExtension( + name="torchao.prototype.mxfp8_cuda", + sources=mxfp8_sources, + include_dirs=[ + mxfp8_extension_dir, # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu + "/usr/local/cuda-12.8/include", # CUDA 12.8 headers + ], + library_dirs=[ + "/usr/local/cuda-12.8/lib64", # CUDA 12.8 libraries + ], + extra_compile_args={ + "cxx": ["-std=c++17", "-O3"], + "nvcc": nvcc_args, + }, + extra_link_args=["-lcuda", "-lcudart"], + ), + ) + # Only build the cutlass_90a extension if sm90a is in the architecture flags if ( cutlass_90a_sources is not None diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index d649b2e04a..6b0aab129c 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -42,7 +42,7 @@ triton_to_mxfp8_dim1_reference, unpack_uint4, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, @@ -56,6 +56,15 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +# TODO: shared utils file for benchmarking and testing +def to_mx_dim1_reference(x_hp, block_size, scaling_mode): + x_hp = x_hp.t().contiguous() + scale_d1, data_d1 = to_mx( + x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode + ) + return data_d1.t(), scale_d1 + + @pytest.mark.skip( reason="TODO debug CI failure, low pri since this is not used in the MX code" # noqa: E501 ) @@ -488,3 +497,99 @@ def test_rearrange(shape): eager = to_blocked(scales, False) triton = to_blocked(scales, True) torch.testing.assert_close(eager, triton, atol=0, rtol=0) + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("M", (32, 64, 2048)) +@pytest.mark.parametrize("K", (32, 64, 2048)) +@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16)) +@pytest.mark.parametrize( + "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL) +) +def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): + from torchao.prototype import mxfp8_cuda + + scaling_mode_str = ( + "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil" + ) + block_size = 32 + + # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
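+ # Return-value layout exercised below (descriptive note, not new API surface):
+ # mxfp8_cuda.quantize returns (output_rowwise, output_colwise, scales_rowwise,
+ # scales_colwise); with rowwise=False and colwise=True only the colwise pair is
+ # populated, which is why the rowwise slots are unpacked into `_` further down.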
+ x = ( + torch.arange(0, M * K, dtype=input_dtype, device="cuda") + .reshape(M, K) + .contiguous() + ) + + y_d1_ref, s_d1_ref = to_mx_dim1_reference( + x, + block_size=block_size, + scaling_mode=scaling_mode, + ) + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=False, + colwise=True, + scaling_mode=scaling_mode_str, + scale_dim_x=1, + scale_dim_y=block_size, + ) + + # check scales + torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0) + + # check quantized values + torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0) + assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match" + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +def test_cuda_mx_dim0_not_supported(): + from torchao.prototype import mxfp8_cuda + + M, K = 64, 64 + block_size = 32 + x = ( + torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + .reshape(M, K) + .contiguous() + ) + with pytest.raises(RuntimeError): + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=True, + colwise=False, + scale_dim_x=block_size, + scale_dim_y=1, + ) + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +def test_cuda_mx_dim1_invalid_block_size(): + from torchao.prototype import mxfp8_cuda + + M, K = 64, 64 + x = ( + torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + .reshape(M, K) + .contiguous() + ) + invalid_block_size = 4 + with pytest.raises(RuntimeError): + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=False, + colwise=True, + scale_dim_x=1, + scale_dim_y=invalid_block_size, + ) diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu new file mode 100644 index 0000000000..ffb91d38c6 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu @@ -0,0 +1,112 @@ +// CUDA bridge for MXFP8 quantization + +#include "mxfp8_quantize.cuh" +#include +#include +#include + + +namespace mxfp8 { + +// Convert PyTorch scalar type to our DType enum +DType get_input_dtype(const torch::Tensor &t) { + switch (t.scalar_type()) { + case torch::kFloat32: + return DType::kFloat32; + case torch::kFloat16: + return DType::kFloat16; + case torch::kBFloat16: + return DType::kBFloat16; + case torch::kUInt8: + return DType::kByte; + default: + TORCH_CHECK(false, "Unsupported input tensor dtype: ", t.scalar_type()); + } +} + +ScaleCalculationMode get_scaling_mode(const std::string &scaling_mode) { + if (scaling_mode.compare("floor") == 0) { + return ScaleCalculationMode::FLOOR; + } else if (scaling_mode.compare("rceil") == 0) { + return ScaleCalculationMode::RCEIL; + } else { + TORCH_CHECK(false, "Unsupported scaling mode: ", scaling_mode, ". Only ['floor', 'rceil'] are supported."); + } +} + +// Convert FP8 format string to DType enum +DType get_output_dtype(const std::string &fp8_format) { + if (fp8_format.compare("e4m3") == 0) { + return DType::kFloat8E4M3; + } else { + TORCH_CHECK(false, "Unsupported FP8 format: ", fp8_format, + ". 
Only 'e4m3' is supported."); + } +} + +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_colwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_x, + int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Get tensor properties + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + + // Get data pointers + const void *input_ptr = input.data_ptr(); + void *output_rowwise_ptr = + output_rowwise.numel() > 0 ? output_rowwise.data_ptr() : nullptr; + void *output_colwise_ptr = + output_colwise.numel() > 0 ? output_colwise.data_ptr() : nullptr; + e8m0_t *scales_rowwise_ptr = + scales_rowwise.numel() > 0 + ? reinterpret_cast(scales_rowwise.data_ptr()) + : nullptr; + e8m0_t *scales_colwise_ptr = + scales_colwise.numel() > 0 + ? reinterpret_cast(scales_colwise.data_ptr()) + : nullptr; + + // Get CUDA stream + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Get strides of scale ptrs + int64_t scale_rowwise_stride_dim0 = scales_rowwise.strides()[0]; + int64_t scale_rowwise_stride_dim1 = scales_rowwise.strides()[1]; + int64_t scale_colwise_stride_dim0 = scales_colwise.strides()[0]; + int64_t scale_colwise_stride_dim1 = scales_colwise.strides()[1]; + +#if defined(DEBUG) + printf("mxfp8_quantize_cuda:\n"); + printf("Quantizing input tensor of size %ld x %ld\n", rows, cols); + printf("scaling_mode: %s\n", scaling_mode.c_str()); + printf("Scale dim x: %ld\n", scale_dim_x); + printf("Scale dim y: %ld\n", scale_dim_y); + printf("Rowwise scale shape: %ld x %ld\n", scales_rowwise.sizes()[0], scales_rowwise.sizes()[1]); + printf("Colwise scale shape: %ld x %ld\n", scales_colwise.sizes()[0], scales_colwise.sizes()[1]); + printf("scale_rowwise_stride_dim0 = %ld\n", scale_rowwise_stride_dim0); + printf("scale_rowwise_stride_dim1 = %ld\n", scale_rowwise_stride_dim1); + printf("scale_colwise_stride_dim0 = %ld\n", scale_colwise_stride_dim0); + printf("scale_colwise_stride_dim1 = %ld\n", scale_colwise_stride_dim1); +#endif + + // Call the quantization kernel + MXFP8Quantizer::quantize(input_ptr, + output_rowwise_ptr, output_colwise_ptr, + scales_rowwise_ptr, scales_colwise_ptr, + scale_rowwise_stride_dim0, scale_rowwise_stride_dim1, + scale_colwise_stride_dim0, scale_colwise_stride_dim1, + rows, cols, + get_input_dtype(input), get_output_dtype(fp8_format), + scale_dim_x, scale_dim_y, + get_scaling_mode(scaling_mode), + stream); +} + +} // namespace mxfp8 diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp new file mode 100644 index 0000000000..1f76788133 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp @@ -0,0 +1,128 @@ +// PyBind wrapping for the mxfp8 extension +#include +#include +#include +#include +#include + +namespace mxfp8 { + +// Forward declarations +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_columnwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_x, + int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode); + +// Helper for tensor validation +void check_cuda_tensor(const torch::Tensor &t, const char *name) { + TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); +} + +// Helper to validate FP8 format +void validate_fp8_format(const std::string 
&fp8_format) { + TORCH_CHECK(fp8_format.compare("e4m3") == 0, + "fp8_format must be 'e4m3', got: ", fp8_format); +} + +// Helper to validate scale dimensions +void validate_scale_dimensions(int64_t scale_dim_x, int64_t scale_dim_y) { + TORCH_CHECK(scale_dim_x == 1 || scale_dim_x == 32, + "scale_dim_x must be 1 or 32, got: ", scale_dim_x); + TORCH_CHECK(scale_dim_y == 1 || scale_dim_y == 32, + "scale_dim_y must be 1 or 32, got: ", scale_dim_y); +} + +// Main quantization function +std::tuple +mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, + int64_t scale_dim_x, int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Validate inputs + TORCH_CHECK(!rowwise, "rowwise scaling is not supported yet"); + check_cuda_tensor(input, "input"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(input.scalar_type() == torch::kFloat32 || + input.scalar_type() == torch::kFloat16 || + input.scalar_type() == torch::kBFloat16, + "Input must be float32, float16, or bfloat16"); + TORCH_CHECK(rowwise || colwise, + "At least one of rowwise or colwise must be true"); + + validate_scale_dimensions(scale_dim_x, scale_dim_y); + validate_fp8_format(fp8_format); + + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + TORCH_CHECK((rows >= 32) && (rows % 32 == 0), "rows must be a multiple of 32"); + TORCH_CHECK((cols >= 32) && (cols % 32 == 0), "cols must be a multiple of 32"); + + c10::cuda::CUDAGuard device_guard(input.device()); + + // Create tensor options + const auto options_fp8 = torch::TensorOptions() + .dtype(torch::kFloat8_e4m3fn) // FP8 stored as uint8 + .device(input.device()); + + const auto options_scale = torch::TensorOptions() + .dtype(torch::kFloat8_e8m0fnu) // E8M0 stored as uint8 + .device(input.device()); + + // Allocate output tensors + torch::Tensor output_rowwise, output_colwise; + torch::Tensor scales_rowwise, scales_colwise; + + if (rowwise) { + const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x; + output_rowwise = torch::empty({rows, cols}, options_fp8); + scales_rowwise = torch::empty({rows, num_col_blocks}, options_scale); + } else { + output_rowwise = torch::empty({0}, options_fp8); + scales_rowwise = torch::empty({0}, options_scale); + } + + if (colwise) { + const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y; + output_colwise = torch::empty_strided({rows, cols}, {1, rows}, options_fp8); + // Need scales_colwise to be this shape so the 'col' dim stride is 1, + // for colwise scaling, we can avoid uncoalesced writes to global memory. + // This is because each of the 32 threads in a warp will be computing + // a scale for a different column of 32 input data values, then each writing + // that scale to global memory - so the stride along this `col` dim should be 1 + // so writes can be coalesced into a single transaction. + scales_colwise = torch::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale); + } else { + output_colwise = torch::empty({0}, options_fp8); + scales_colwise = torch::empty({0}, options_scale); + } + + // Call CUDA kernels + mxfp8_quantize_cuda(input, + output_rowwise, output_colwise, + scales_rowwise, scales_colwise, + rowwise ? scale_dim_x : 1, // scale_dim_x + colwise ? 
scale_dim_y : 1, // scale_dim_y + fp8_format, scaling_mode); + + return std::make_tuple(output_rowwise, output_colwise, scales_rowwise, + scales_colwise); +} + +} // namespace mxfp8 + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "MXFP8 Quantization PyTorch Extension"; + + m.def("quantize", &mxfp8::mxfp8_quantize, "MXFP8 quantization", + py::arg("input"), py::arg("rowwise") = true, py::arg("colwise") = false, + py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32, + py::arg("fp8_format") = "e4m3", + py::arg("scaling_mode") = "floor"); +} diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh new file mode 100644 index 0000000000..9b86c680d0 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh @@ -0,0 +1,1049 @@ +// Adapted from https://github.com/NVIDIA/TransformerEngine +// License - Apache-2.0 +// https://github.com/NVIDIA/TransformerEngine/blob/main/LICENSE +// * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Portions (c) Meta Platforms, Inc. and affiliates. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Use official CUDA PTX library +#include "ptx.cuh" +#include +#include + +#define MIN_CUDA_SM 1000 // SM90 = 900, SM100 = 1000 + +// Check if we're compiling for supported architecture +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < MIN_CUDA_SM) +#warning \ + "MXFP8 quantization requires SM90+ (Hopper) or SM100+ (Blackwell) architecture. Kernel will be disabled for this architecture." +#endif + +// Architecture detection for native FP8 support +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 +#define HAS_NATIVE_FP8_CONVERSION 1 +#else +#define HAS_NATIVE_FP8_CONVERSION 0 +#endif + +enum class DType { + kByte, + kFloat32, + kFloat16, + kBFloat16, + kFloat8E4M3, + kFloat8E5M2 +}; + +enum class ScaleCalculationMode { + FLOOR, // uses software scaling + RCEIL, // uses hardware scaling +}; + +// Data types +using e8m0_t = uint8_t; +using bfloat16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; + +constexpr size_t get_dtype_bits(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return 32; + case DType::kBFloat16: + return 16; + case DType::kFloat8E4M3: + return 8; + default: + // TODO: something smarter than this + return 0; + } +} + +// FP32 constants +constexpr int32_t FP32_MANTISSA_BITS = 23; +constexpr int32_t FP32_EXPONENT_BIAS = 127; + +// BF16 constants +constexpr int32_t BF16_MANTISSA_BITS = 7; +constexpr int32_t BF16_EXPONENT_BIAS = 127; + +// FP8E4M3 constants +constexpr int32_t F8E4M3_MAX_POW2 = 8; +constexpr float F8E4M3_MAX = 448.0; + +// FP8E8M0 constants +constexpr int32_t E8M0_EXPONENT_BIAS = 127; + +// 1. Base template (for unsupported types) +template struct DataTypeTraits { + static constexpr bool is_supported = false; +}; + +// 2. Specialization for float32 +template <> struct DataTypeTraits { + static constexpr bool is_supported = true; + static constexpr int mantissa_bits = 23; + static constexpr int exponent_bias = 127; + + __device__ static __forceinline__ float to_float(const float val) { + return val; + } +}; + +// 3. 
Specialization for bfloat16 +template <> struct DataTypeTraits { + static constexpr bool is_supported = true; + static constexpr int mantissa_bits = 7; + static constexpr int exponent_bias = 127; + + __device__ static __forceinline__ float to_float(const nv_bfloat16 val) { + return __bfloat162float(val); + } +}; + +__device__ static __forceinline__ e8m0_t +calculate_e8m0_biased_scale(const float amax) { + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239 + const int32_t int_amax = *reinterpret_cast(&amax); + const int32_t extracted_pow2 = + ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111) - FP32_EXPONENT_BIAS; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244 + int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256 + scale_unbiased = max(scale_unbiased, -E8M0_EXPONENT_BIAS); + scale_unbiased = min(scale_unbiased, E8M0_EXPONENT_BIAS + 1); + int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26 + const e8m0_t e8m0_biased_scale = + *reinterpret_cast(&scale_with_e8m0_bias); + return e8m0_biased_scale; +} + +// Constants for MXFP8 kernel +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = + MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; // 1 * 1 = 1 +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 64/16 = 4 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 64 / 4 = 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = + MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +constexpr size_t THREADS_PER_WARP = 32; // lol + +// Utility macros +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) + +// Vector type for loading/storing multiple elements +template struct Vec { + union { + T elt[N]; + } data; + + __device__ inline void clear() { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = T(0); + } + } + + __device__ inline void load_from(const T *ptr) { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = ptr[i]; + } + } + + __device__ inline void store_to(T *ptr) const { +#pragma unroll + for (int i = 0; i < N; ++i) { + ptr[i] = data.elt[i]; + } + } +}; + +// Source: +// 
https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L971 +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) + ? 1 + : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp)); +} + +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L937 +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile("{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast<e8m0_t *>(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && + !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +// Quantization limits +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L929 +template <typename T> struct Quantized_Limits { + static constexpr float max_norm = 448.0f; // For E4M3 + static constexpr float max_norm_rcp = 1.0f / max_norm; +}; + +// Warp reduction utilities +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L867 +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four + * MXFP8 scaling factors. To compute an actual scaling factor for 32 + * consecutive elements, only 8 threads need to participate, thus splitting + * the warp into 4x smaller subwarps 8-thread width. 'Butterfly' reduction is + * used inside subwarps.
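+ * In the kernel below, ELEMS_PER_THREAD = 16, so for row-wise scaling with
+ * SCALE_DIM_X = 32 we get THREADS_PER_SCALE_X_ROWWISE = 2 and a subwarp width
+ * of 2: each pair of adjacent threads covers one 32-element scaling block, and
+ * the reduction is a single __shfl_down_sync step (offset 1) followed by a
+ * broadcast from the pair's lane 0 (worked example for this configuration).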
+ */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = + __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp + // lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L813C1-L824C2 +template +__device__ __forceinline__ float warp_reduce_max(const float m) { + float tmp = m; +#pragma unroll + for (int delta = num_elems / 2; delta > 0; delta /= 2) { + const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); + __builtin_assume(tmp >= 0); + __builtin_assume(other_m >= 0); + tmp = fmaxf(tmp, other_m); + } + return tmp; +} + +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L841C1-L857C2 +template +__device__ __forceinline__ compute_t reduce_max(const compute_t m, + const int warpid) { + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (threadIdx.x % 32 == 0) { + staging[warpid] = my_warp_max; + } + __syncthreads(); + compute_t result = 0.f; + if (warpid == 0) { + const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; + result = warp_reduce_max(my_max); + } + return result; +} + +// https://stackoverflow.com/a/51549250 +// TODO: handle -0 case +__device__ __forceinline__ float atomicMaxFloat(float *addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int *)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int *)addr, __float_as_uint(value))); + + return old; +} + +// TMA descriptor creation +inline CUtensorMapDataType get_dtype_for_tma(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case DType::kFloat16: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + case DType::kBFloat16: + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kByte: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + default: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } +} + +// Reference: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/common.cu#L137 +// This was modified to make it compatible with our implementation and avoid +// using internal TE types. 
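+// Worked example of the descriptor parameters (illustrative numbers only): for
+// a 16384 x 16384 bf16 input with a 32 x 64 shared-memory tile, size is
+// {16384, 16384}, stride[0] = 16384 * 16 / 8 = 32768 bytes, and boxSize is
+// {64, 32}.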
+inline void create_2D_tensor_map(CUtensorMap &tensorMap, void *data_ptr, + DType dtype, const size_t rows, + const size_t cols, uint32_t shmem_y, + uint32_t shmem_x, const size_t stride_elems, + const size_t type_num_bits) { + // Get function pointer to cuTensorMapEncodeTiled + static void *driver_ptr = nullptr; + if (!driver_ptr) { + cudaDriverEntryPointQueryResult result; + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, + cudaEnableDefault, &result); + } + auto cuTensorMapEncodeTiled = + reinterpret_cast(driver_ptr); + + constexpr uint32_t rank = 2; + uint64_t size[rank] = {cols, rows}; + uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / + 8}; // (cols * bits per element) / 8 + uint32_t boxSize[rank] = {shmem_x, shmem_y}; + uint32_t elemStride[rank] = {1, 1}; + +#if defined(DEBUG) + printf("TMA Descriptor: global_shape=(%llu, %llu), tile_shape=(%u, %u), " + "stride_bytes=%llu\n", + (unsigned long long)size[1], (unsigned long long)size[0], boxSize[1], + boxSize[0], (unsigned long long)stride[0]); +#endif + + cuTensorMapEncodeTiled( + &tensorMap, get_dtype_for_tma(dtype), rank, data_ptr, size, stride, + boxSize, elemStride, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); +} + +// Helper functions for TMA operations +__device__ inline void copy_2d_to_shared(void *smem, + const CUtensorMap *tensor_map, + uint32_t x, uint32_t y, + size_t smem_size, uint64_t *mbar, + bool is_master) { + if (is_master) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(smem), + reinterpret_cast(tensor_map), x, y, mbar); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(mbar, smem_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(mbar); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// TorchAO shared quantization utils +//////////////////////////////////////////////////////////////////////////////// + +/** + * Convert e8m0 biased scale to float32 scale following torchao implementation + * torchao ref: + * https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L275C1-L277C30 + */ +__device__ __forceinline__ float e8m0_to_scale_fp32(e8m0_t e8m0_biased_scale) { + int32_t exponent_as_int32 = static_cast(e8m0_biased_scale); + int32_t float_bits = exponent_as_int32 << FP32_MANTISSA_BITS; + float scale_fp32 = *reinterpret_cast(&float_bits); + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L286 + const float F32_MIN_NORMAL = exp2f(-FP32_EXPONENT_BIAS + 1); + scale_fp32 = max(scale_fp32, F32_MIN_NORMAL); + + return scale_fp32; +} + +/** + * Quantize a single value using torchao-style clamping and conversion + * torchao ref: + * https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289 + */ +template +__device__ __forceinline__ OType torchao_quantize_value(float input_value, + float inv_scale_fp32) { + // Scale the input value + float data_lp = input_value * inv_scale_fp32; + + // Apply torchao-style clamping + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L301C23-L301C74 + data_lp = min(data_lp, F8E4M3_MAX); + data_lp = 
max(data_lp, -F8E4M3_MAX); + + return static_cast(data_lp); +} + +/** + * Complete torchao-style quantization: calculate scale and convert values + * Template parameters ensure compile-time array size checking for safety + */ +template +__device__ __forceinline__ float +quantize_block(float amax, e8m0_t &out_scale, + const float (&input_values)[NUM_VALUES], + OType (&output_values)[NUM_VALUES]) { + + float inv_scale_fp32; + if constexpr (ScalingMode == ScaleCalculationMode::FLOOR) { + // FLOOR scaling. + out_scale = calculate_e8m0_biased_scale(amax); + + // Convert scale to float32 + float scale_fp32 = e8m0_to_scale_fp32(out_scale); + + // Calculate inverse scale for fast multiplication + inv_scale_fp32 = __fdiv_rn(1.0f, scale_fp32); + + // Quantize all values +#pragma unroll + for (int i = 0; i < NUM_VALUES; ++i) { + output_values[i] = + torchao_quantize_value(input_values[i], inv_scale_fp32); + } + + } else { + // RCEIL scaling. + out_scale = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + inv_scale_fp32 = exp2f_rcp(out_scale); + +#pragma unroll + for (int i = 0; i < NUM_VALUES; ++i) { + output_values[i] = + static_cast(input_values[i] * inv_scale_fp32); + } + } + +} + +/** + * Bounds checking helper for IMA avoidance + */ +struct BoundsChecker { + const size_t rows, cols; + const size_t chunk_offset_X, chunk_offset_Y; + + __device__ __forceinline__ BoundsChecker(size_t r, size_t c, size_t cox, + size_t coy) + : rows(r), cols(c), chunk_offset_X(cox), chunk_offset_Y(coy) {} + + __device__ __forceinline__ bool is_out_of_bounds(size_t row, + size_t col) const { + return (row >= rows) || (col >= cols); + } + + __device__ __forceinline__ bool + is_rowwise_out_of_bounds(size_t shmem_y, size_t shmem_x, int j, + size_t row_base) const { + const size_t row = row_base + shmem_y; + const size_t col = chunk_offset_X + shmem_x + j; + return is_out_of_bounds(row, col); + } + + __device__ __forceinline__ bool + is_colwise_out_of_bounds(size_t row_offset, size_t col, + size_t row_base) const { + const size_t row = row_base + row_offset; + return is_out_of_bounds(row, col); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// MXFP8 quantization kernel +//////////////////////////////////////////////////////////////////////////////// + +// Main MXFP8 quantization kernel (with TMA) +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + mxfp8_quantize_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scales_rowwise_stride_dim0, + const size_t scales_rowwise_stride_dim1, + const size_t scales_colwise_stride_dim0, + const size_t scales_colwise_stride_dim1) { + +#if defined(DEBUG) + printf("mxfp8_quantize_kernel: rows=%llu, cols=%llu, " + "scales_rowwise_stride_dim0=%llu, scales_rowwise_stride_dim1=%llu, " + "scales_colwise_stride_dim0=%llu, scales_colwise_stride_dim1=%llu\n", + (unsigned long long)rows, (unsigned long long)cols, + (unsigned long long)scales_rowwise_stride_dim0, + (unsigned long long)scales_rowwise_stride_dim1, + (unsigned long long)scales_colwise_stride_dim0, + (unsigned long long)scales_colwise_stride_dim1); + + if (ScalingMode == ScaleCalculationMode::FLOOR) { + printf("mxfp8_quantize_kernel: scaling_mode: floor\n"); + } else if (ScalingMode == 
ScaleCalculationMode::RCEIL) { + printf("mxfp8_quantize_kernel: scaling_mode: rceil\n"); + } else { + printf("mxfp8_quantize_kernel: unknown scaling mode\n"); + } +#endif + + + static_assert(DataTypeTraits<IType>::is_supported, + "Input data type is not supported by this kernel."); + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y; // 64 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 64 = 64 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = + blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = + blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128 e8m0_t aligned + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_X][MXFP8_SHMEM_DIM_Y]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in +// the barrier.
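+// Summary of the handshake used below: initialize_barriers() makes each mbar[i]
+// expect MXFP8_THREADS_PER_CHUNK arrivals; in copy_2d_to_shared() the master
+// thread issues the TMA copy and posts the expected byte count via
+// mbarrier_arrive_expect_tx() while every other thread simply arrives;
+// mbarrier_wait_parity() then blocks until both conditions are satisfied, and
+// `parity` flips once per chunk so the same barriers can be reused.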
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; + +// Process chunks +#pragma unroll + // Calculate chunk offsets + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +// Prefetch initial data +#pragma unroll + // Kick off TMA async copy from global to shared memory + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + +// Process iterations +#pragma unroll + // Iterate through the chunk along the Y dim + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Prefetch next iteration data + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#if defined(DEBUG_SMEM) + // Debugging smem data + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) { + printf("Shared memory values:\n"); + for (int b = 0; b < MXFP8_BUFFERS_NUM; b++) { + for (int y = 0; y < MXFP8_SHMEM_DIM_Y; y++) { + for (int x = 0; x < MXFP8_SHMEM_DIM_X; x++) { + printf("in_sh[%d][%d][%d] = %f\n", b, y, x, + (float)in_sh[b][y][x]); + } + } + } + } +#endif + + // ======== RowWise SCALING ======== + + // Updated Row-wise scaling section: + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out_c; + + // Create bounds checker for this chunk + BoundsChecker bounds(rows, cols, chunk_offset_X, chunk_offset_Y); + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + // Load from shared memory into thread local registers + 
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + + // Calculate thread-local amax and prepare input values +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool out_of_bounds = bounds.is_rowwise_out_of_bounds( + shmem_offset_y, shmem_offset_x, j, row_base); + + // Load and convert to float + float elt = DataTypeTraits::to_float(in.data.elt[j]); + in_compute[j] = elt; + + // Update thread local amax + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + // Update block local amax + block_amax = fmaxf(block_amax, thread_amax); + + // Reduce amax across subwarp + const float subwarp_amax = + subwarp_reduce_max_broadcast(thread_amax); + + + // Apply quantization to the local block. + e8m0_t e8m0_biased_scale; + OType quantized_values[ELEMS_PER_THREAD]; + + quantize_block( + subwarp_amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor (only a single thread writes it to global + // memory) + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scales_rowwise_stride_dim0 + + global_scales_offset_X; + scales_rowwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = quantized_values[j]; + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + +#if defined(DEBUG) + if (tid_rowwise_X == 0 && tid_rowwise_Y == 0) { + printf("Rowwise: subwarp_amax=%f, e8m0_scale=%u\n", subwarp_amax, e8m0_biased_scale); + } +#endif + + } + } + // ======== End RowWise SCALING ======== + + // ======== ColWise SCALING ======== + // Column-wise scaling + + if constexpr (USE_COLWISE_SCALING) { + // Create bounds checker for this chunk + BoundsChecker bounds(rows, cols, chunk_offset_X, chunk_offset_Y); + + const size_t col = chunk_offset_X + tid_colwise_X; + const bool col_out_of_bounds = (col >= cols); + + float in_compute[SCALE_DIM_Y]; + float amax = 0; + + // Calculate amax and prepare input values +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const bool out_of_bounds = + bounds.is_colwise_out_of_bounds(i, col, row_base); + + // Load and convert to float + float elt = + DataTypeTraits::to_float(in_sh[buff][i][tid_colwise_X]); + in_compute[i] = elt; + + // Update thread local amax + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + // Apply quantization to the local block. + e8m0_t e8m0_biased_scale; + OType quantized_values[SCALE_DIM_Y]; + quantize_block( + amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor to global memory + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + + // Write scale in column major memory layout, shape (cols, num_row_blocks, 1). + // Stride along `cols` dim must be 1, for coalesced writes to global memory. 
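+ // Worked example (numbers for a 64 x 64 input): SCALE_DIM_Y = 32 gives
+ // num_row_blocks = 2; scales_colwise is allocated as shape
+ // (cols, num_row_blocks) = (64, 2) with strides (1, 64), so the thread
+ // handling col = 5 in the second row block writes to
+ // scale_idx = 1 * 64 + 5 * 1 = 69, and adjacent threads (adjacent cols)
+ // write adjacent addresses.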
+ const int scale_idx = + global_scales_offset_Y * scales_colwise_stride_dim1 + + global_scales_offset_X * scales_colwise_stride_dim0; + + // Bounds check for scale writing + const bool row_out_of_bounds = (row_base >= rows); + if (!row_out_of_bounds && !col_out_of_bounds) { + scales_colwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values to shared memory +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][tid_colwise_X][i] = quantized_values[i]; + } + +#if defined(DEBUG) + if (tid_colwise_X == 0) { + printf("Colwise: amax=%f, e8m0_scale=%u\n", amax, e8m0_biased_scale); + } +#endif + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + if constexpr (USE_ROWWISE_SCALING) { + const int chunk_it_offset_y = + chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + // Swap logical destination offsets for TMA to write into column major layout. + const int chunk_it_offset_y = chunk_offset_X; + const int chunk_it_offset_x = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + destroy_barriers(mbar, is_master_thread); + // #endif +} + +// Simple wrapper class for MXFP8 quantization +class MXFP8Quantizer { +public: + // Quantize a tensor using MXFP8 + // input: pointer to input data + // output_rowwise: pointer to row-wise quantized output (can be nullptr) + // output_colwise: pointer to column-wise quantized output (can be nullptr) + // scales_rowwise: pointer to row-wise scaling factors (required if + // output_rowwise is not null) scales_colwise: pointer to column-wise scaling + // factors (required if output_colwise is not null) rows, cols: tensor + // dimensions input_dtype: data type of input output_dtype: FP8 output type + // (fp8e4m3 or fp8e5m2) scale_dim_x: block size for row-wise scaling + // (typically 32) scale_dim_y: block size for column-wise scaling (typically + // 32) + static void + quantize(const void *input, void *output_rowwise, void *output_colwise, + e8m0_t *scales_rowwise, e8m0_t *scales_colwise, + size_t scales_rowwise_stride_dim0, size_t scales_rowwise_stride_dim1, + size_t scales_colwise_stride_dim0, size_t scales_colwise_stride_dim1, + size_t rows, size_t cols, DType input_dtype, DType output_dtype, + size_t scale_dim_x = 32, size_t scale_dim_y = 32, + ScaleCalculationMode scaling_mode = ScaleCalculationMode::FLOOR, + cudaStream_t stream = 0) { + + // Check parameters + assert((scale_dim_x == 1 || scale_dim_x == 32) && + (scale_dim_y == 1 || scale_dim_y == 32)); + assert(output_rowwise != nullptr || output_colwise != nullptr); + + if (output_rowwise) + assert(scales_rowwise != nullptr); + if (output_colwise) + assert(scales_colwise != nullptr); + + // Calculate grid dimensions + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + // Create TMA descriptors + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + int32_t input_bits_per_elem = get_dtype_bits(input_dtype); + int32_t output_bits_per_elem = get_dtype_bits(output_dtype); + + create_2D_tensor_map(tensor_map_input, const_cast(input), + input_dtype, + rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + cols, // stride of "slowest moving" dim + input_bits_per_elem); // bits per elem in input + + if (output_rowwise) { + create_2D_tensor_map( + tensor_map_output_rowwise, output_rowwise, output_dtype, + rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + cols, // stride of "slowest moving" dim + output_bits_per_elem); // bits per elem in output fp8e4m3 + } + + if (output_colwise) { + create_2D_tensor_map( + tensor_map_output_colwise, output_colwise, output_dtype, + cols, rows, // Swap for column major layout + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, + rows, // stride of "slowest moving" dim + output_bits_per_elem); // bits per elem in output fp8e4m3 + } + +// Launch kernel based on input/output types and scaling dimensions +// Only compile kernel launches for SM90+ +#if defined(__CUDACC__) && \ + (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM) + + // Use TMA and mbarrier instructions +#define LAUNCH_KERNEL(IType, OType, SCALE_Y, SCALE_X, ScalingMode) \ + 
mxfp8_quantize_kernel \ + <<>>( \ + tensor_map_input, tensor_map_output_rowwise, \ + tensor_map_output_colwise, scales_rowwise, scales_colwise, rows, \ + cols, scales_rowwise_stride_dim0, scales_rowwise_stride_dim1, \ + scales_colwise_stride_dim0, scales_colwise_stride_dim1); + + // Validate output dtype. + if (output_dtype != DType::kFloat8E4M3) { + printf("unsupported output dtype, must be fp8e4m3\n"); + exit(1); + } + + if (scaling_mode == ScaleCalculationMode::FLOOR) { + if (input_dtype == DType::kFloat32) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(float, fp8e4m3, 1, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } + } else if (input_dtype == DType::kBFloat16) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 1, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else if (scaling_mode == ScaleCalculationMode::RCEIL) { + if (input_dtype == DType::kFloat32) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(float, fp8e4m3, 1, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } + } else if (input_dtype == DType::kBFloat16) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 1, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else { + printf("unsupported scaling mode\n"); + exit(1); + } + +#undef LAUNCH_KERNEL + +#endif + } +}; diff --git a/torchao/csrc/cuda/mx_kernels/ptx.cuh b/torchao/csrc/cuda/mx_kernels/ptx.cuh new file mode 100644 index 0000000000..ba06746dbd --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/ptx.cuh @@ -0,0 +1,290 @@ +// Adapted from https://github.com/NVIDIA/TransformerEngine +// License - Apache-2.0 +// https://github.com/NVIDIA/TransformerEngine/blob/main/LICENSE +// * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Portions (c) Meta Platforms, Inc. and affiliates. + +/*! 
\file ptx.cuh + * \brief BW PTX + */ + +#include +#include + + +namespace ptx { + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, + const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void +mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void +cp_async_bulk_tensor_1d_global_to_shared(uint64_t *dst_shmem, + const uint64_t *src_global_ptr, + const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware subtracts bytes from expect_tx counter, must reach zero + asm volatile("cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"( + dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, + const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e.
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
+// global -> shared::cluster
+__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
+    uint64_t *dst_shmem, const uint64_t *tensor_map_ptr,
+    const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) {
+  uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem);
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  // triggers async copy, i.e. the thread continues until wait() on mbarrier
+  // barrier condition:
+  // - leader must arrive (i.e. 1 thread as set above)
+  // - TMA hardware subtracts bytes from expect_tx counter, must reach zero
+  asm volatile(
+      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
+      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(
+          dst_shmem_ptr),
+      "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr)
+      : "memory");
+}
+
+__device__ __forceinline__ bool
+mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
+  uint32_t waitComplete;
+  asm volatile("{\n\t .reg .pred P_OUT; \n\t"
+               "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
+               "selp.b32 %0, 1, 0, P_OUT; \n"
+               "}"
+               : "=r"(waitComplete)
+               : "r"(mbar_ptr), "r"(parity)
+               : "memory");
+  return static_cast<bool>(waitComplete);
+}
+
+__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar,
+                                                     const uint32_t parity) {
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
+  }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
+// shared::cta -> global
+__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(
+    uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) {
+  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
+  asm volatile(
+      "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(
+          dst_global_ptr),
+      "r"(src_shmem_ptr), "r"(size)
+      : "memory");
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
+// shared::cta -> global
+__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
+    const uint64_t *tensor_map_ptr, const uint32_t offset_x,
+    const uint32_t offset_y, uint64_t *src_shmem) {
+  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
+  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
+               "{%1, %2}], [%3];" ::"l"(tensor_map_ptr),
+               "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr)
+               : "memory");
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
+__device__ __forceinline__ void cp_async_bulk_wait_group() {
+  asm volatile("cp.async.bulk.wait_group 0;");
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
+template <size_t W>
+__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
+  asm volatile("cp.async.bulk.wait_group.read 0;");
+}
+
+template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
+  asm volatile("cp.async.bulk.wait_group.read 0;");
+}
+template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
+  asm volatile("cp.async.bulk.wait_group.read 1;");
+}
+template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
+  asm volatile("cp.async.bulk.wait_group.read 2;");
+}
+template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
+  asm volatile("cp.async.bulk.wait_group.read 4;");
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
+__device__ __forceinline__ void cp_async_bulk_commit_group() {
+  asm volatile("cp.async.bulk.commit_group;");
+}
+
+// Proxy fence (bi-directional):
+__device__ __forceinline__ void fence_proxy_async() {
+  asm volatile("fence.proxy.async;");
+}
+
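Taken together, these wrappers implement the usual TMA + mbarrier pattern: one elected thread programs the expected transaction byte count and issues the bulk copy, every other thread just arrives, and all threads then wait on the barrier with a parity bit that flips each phase. The minimal kernel below strings that flow together for a single 1-D tile; it is a sketch, not part of the patch, and it omits the multi-buffer pipelining a real kernel would use. The tile size and names are made up, and `ptx.cuh` is assumed to be on the include path.

```cuda
// Sketch only: single-buffer load -> scale -> store round trip using the
// ptx:: wrappers above. Requires sm_90+; `in`/`out` must be 16-byte aligned.
#include <cstdint>
#include <cuda_runtime.h>
#include "ptx.cuh"

constexpr int TILE_ELEMS = 4096;
constexpr uint32_t TILE_BYTES = TILE_ELEMS * sizeof(float);

__global__ void scale_tile(const float *in, float *out, float mult) {
  __shared__ alignas(128) float tile[TILE_ELEMS];
  __shared__ alignas(8) uint64_t bar;

  const bool is_master = (threadIdx.x == 0);
  if (is_master) {
    ptx::mbarrier_init(&bar, blockDim.x);  // every thread arrives each phase
    ptx::fence_proxy_async();              // make the init visible to TMA
  }
  __syncthreads();

  if (is_master) {
    // Program the transaction count, then kick off the async bulk copy.
    ptx::mbarrier_arrive_expect_tx(&bar, TILE_BYTES);
    ptx::cp_async_bulk_tensor_1d_global_to_shared(
        reinterpret_cast<uint64_t *>(tile),
        reinterpret_cast<const uint64_t *>(in), TILE_BYTES, &bar);
  } else {
    ptx::mbarrier_arrive(&bar);
  }
  // Phase 0 completes once all threads have arrived and all bytes have landed.
  ptx::mbarrier_wait_parity(&bar, 0);

  for (int i = threadIdx.x; i < TILE_ELEMS; i += blockDim.x) {
    tile[i] *= mult;
  }
  // Generic-proxy writes to shared memory must be fenced before the async
  // proxy (TMA) reads them back out.
  ptx::fence_proxy_async();
  __syncthreads();

  if (is_master) {
    ptx::cp_async_bulk_tensor_1d_shared_to_global(
        reinterpret_cast<uint64_t *>(out),
        reinterpret_cast<const uint64_t *>(tile), TILE_BYTES);
    ptx::cp_async_bulk_commit_group();
    ptx::cp_async_bulk_wait_group_read<0>();
  }
}
```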
+__device__ __forceinline__ void fence_proxy_async_shared_cta() {
+  asm volatile("fence.proxy.async.shared::cta;");
+}
+
+} // namespace ptx
+
+namespace {
+
+template <int num_barriers, int THREADS_PER_BLOCK>
+__forceinline__ __device__ void
+initialize_barriers(uint64_t *mbar, const bool is_master_thread) {
+  if (is_master_thread) {
+    // Initialize barrier. All `blockDim.x * blockDim.y` threads in block
+    // participate.
+#pragma unroll
+    for (int iter = 0; iter < num_barriers; ++iter) {
+      ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK);
+    }
+    ptx::fence_proxy_async_shared_cta();
+  }
+  // Syncthreads so initialized barrier is visible to all threads.
+  __syncthreads();
+}
+
+template <int num_barriers>
+__forceinline__ __device__ void destroy_barriers(uint64_t *mbar,
+                                                 const bool is_master_thread) {
+  // Destroy barrier. This invalidates the memory region of the barrier. If
+  // further computations were to take place in the kernel, this allows the
+  // memory location of the shared memory barrier to be reused.
+  if (is_master_thread) {
+#pragma unroll
+    for (int iter = 0; iter < num_barriers; ++iter) {
+      ptx::mbarrier_invalid(&mbar[iter]);
+    }
+  }
+}
+
+__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
+                                                  const size_t num_bytes,
+                                                  uint64_t *barrier,
+                                                  const bool is_master_thread) {
+  if (is_master_thread) {
+    // Initiate bulk tensor copy
+    ptx::cp_async_bulk_tensor_1d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst),
+        reinterpret_cast<const uint64_t *>(src), num_bytes, barrier);
+
+    // Arrive on the barrier and tell how many bytes are expected to come in.
+    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
+  } else {
+    // Other threads just arrive
+    ptx::mbarrier_arrive(barrier);
+  }
+}
+
+__forceinline__ __device__ void
+copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X,
+                  const size_t chunk_Y, const size_t num_bytes,
+                  uint64_t *barrier, const bool is_master_thread) {
+  if (is_master_thread) {
+    // Initiate bulk tensor copy
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst),
+        reinterpret_cast<const uint64_t *>(src), chunk_X, chunk_Y, barrier);
+
+    // Arrive on the barrier and tell how many bytes are expected to come in.
+    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
+  } else {
+    // Other threads just arrive
+    ptx::mbarrier_arrive(barrier);
+  }
+}
+
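These helpers package the elected-thread pattern so the calling kernel only has to manage parity. The sketch below shows how a staged load loop might drive them, mainly to illustrate the barrier/parity bookkeeping rather than a fully overlapped pipeline. The buffer count, tile geometry, thread count, and template parameters follow the reconstruction above and are placeholders, not the real kernel's configuration.

```cuda
// Hypothetical consumer loop built on the helpers above (not the patch's
// kernel). Assumes ptx.cuh is in the same translation unit and sm_90+.
#include <cstdint>
#include <cuda.h>
#include "ptx.cuh"

constexpr int kBuffers = 2;
constexpr size_t kChunkBytes = 64 * 64 * sizeof(uint16_t);  // one bf16 tile

__global__ void consume_tiles(const __grid_constant__ CUtensorMap tensor_map,
                              int num_chunks) {
  __shared__ alignas(128) uint16_t stage[kBuffers][64 * 64];
  __shared__ alignas(8) uint64_t bars[kBuffers];

  const bool is_master = (threadIdx.x == 0 && threadIdx.y == 0);
  initialize_barriers<kBuffers, 128>(bars, is_master);  // 128-thread block assumed

  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    const int buf = chunk % kBuffers;
    // Each barrier completes one phase per chunk it stages, so the parity bit
    // flips every time the same buffer is reused.
    const uint32_t parity = (chunk / kBuffers) & 1;

    copy_2d_to_shared(stage[buf], &tensor_map, /*chunk_X=*/0,
                      /*chunk_Y=*/chunk * 64, kChunkBytes, &bars[buf],
                      is_master);
    ptx::mbarrier_wait_parity(&bars[buf], parity);

    // ... consume stage[buf] here ...
    __syncthreads();
  }

  destroy_barriers<kBuffers>(bars, is_master);
}
```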
+__forceinline__ __device__ void copy_2d_to_sharedx2(
+    void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1,
+    void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2,
+    const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) {
+  if (is_master_thread) {
+    // Initiate bulk tensor copy
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst),
+        reinterpret_cast<const uint64_t *>(src), chunk_X1, chunk_Y1, barrier);
+
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst2),
+        reinterpret_cast<const uint64_t *>(src2), chunk_X2, chunk_Y2, barrier);
+
+    // Arrive on the barrier and tell how many bytes are expected to come in.
+    ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes);
+  } else {
+    // Other threads just arrive
+    ptx::mbarrier_arrive(barrier);
+  }
+}
+
+__forceinline__ __device__ void copy_2d_to_sharedx3(
+    void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1,
+    void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2,
+    void *dst3, const void *src3, const size_t chunk_X3, const size_t chunk_Y3,
+    const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) {
+  if (is_master_thread) {
+    // Initiate bulk tensor copy
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst),
+        reinterpret_cast<const uint64_t *>(src), chunk_X1, chunk_Y1, barrier);
+
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst2),
+        reinterpret_cast<const uint64_t *>(src2), chunk_X2, chunk_Y2, barrier);
+
+    ptx::cp_async_bulk_tensor_2d_global_to_shared(
+        reinterpret_cast<uint64_t *>(dst3),
+        reinterpret_cast<const uint64_t *>(src3), chunk_X3, chunk_Y3, barrier);
+
+    // Arrive on the barrier and tell how many bytes are expected to come in.
+    ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes);
+  } else {
+    // Other threads just arrive
+    ptx::mbarrier_arrive(barrier);
+  }
+}
+} // namespace
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index e1e37ea7fa..420c83f81d 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1375,17 +1375,21 @@ def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
         return acceptable_shardings
 
     def triton_to_mxfp8_dim1_reference(
-        x_hp: torch.Tensor, block_size
+        x_hp: torch.Tensor,
+        block_size,
+        scaling_mode="FLOOR",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         A reference version of `to_mxfp8_dim1`.
         """
-        from torchao.prototype.mx_formats.mx_tensor import to_mx
+        from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
+
+        scale_mode = ScaleCalculationMode[scaling_mode]
 
         # cast across dim1
         x_hp_d1 = x_hp.t().contiguous()
         scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
-            x_hp_d1, torch.float8_e4m3fn, block_size
+            x_hp_d1, torch.float8_e4m3fn, block_size, scaling_mode=scale_mode
         )
         scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
         return (
@@ -1718,7 +1722,7 @@ def triton_to_mxfp8_dim1(
         raise AssertionError("needs torch version 2.8+ and triton")
 
     def triton_to_mxfp8_dim1_reference(
-        x_hp: torch.Tensor, block_size
+        x_hp: torch.Tensor, block_size, scaling_mode
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")