diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
index 46380cfb55..a7600c12ea 100644
--- a/test/prototype/mx_formats/test_mx_mm.py
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -10,10 +10,12 @@
 
 from torchao.float8.float8_utils import compute_error
 from torchao.ops import mx_fp4_bf16
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
+    is_ROCm_mx_supported,
     is_sm_at_least_100,
 )
 
@@ -57,6 +59,41 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     return compute_error(out_hp, out).item()
 
 
+@pytest.mark.skipif(
+    not is_ROCm_mx_supported(),
+    reason="AMD mxfloat8 test requires ROCm 7.0 on gfx950 GPU",
+)
+def test_hipblaslt_fp8():
+    """Test HIPBLASLT backend for FP8 operations"""
+    a = torch.randn(128, 128, device="cuda")
+    b = torch.randn(128, 128, device="cuda")
+
+    a_mx = MXTensor.to_mx(
+        a, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT
+    )
+    b_mx = MXTensor.to_mx(
+        b, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT
+    )
+
+    # Compute reference result in high precision
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
+        -1, -2
+    )
+
+    # Compute result using HIPBLASLT backend with scaled_mm
+    out = torch._scaled_mm(
+        a_mx._data,
+        b_mx._data.transpose(-1, -2),
+        a_mx._scale_e8m0.view(torch.float8_e8m0fnu),
+        b_mx._scale_e8m0.view(torch.float8_e8m0fnu),
+        out_dtype=torch.bfloat16,
+    )
+
+    # Verify results; TODO: tune a ROCm-specific SQNR threshold
+    sqnr = compute_error(out_hp, out).item()
+    assert sqnr > 80.0, f"SQNR {sqnr} below threshold 80.0"
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index 587d81f6a6..8a38a71365 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -1,7 +1,8 @@
 # MX training and inference with native PyTorch
 
-This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
+This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
+in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 and AMD MI355x hardware.
+
 
 ## Overall status
 
@@ -29,6 +30,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
 gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
 
+# on AMD MI355x GPUs with ROCm 7.0 and gfx950, you can use HIPBLASLT mxfp8 kernels
+gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT
+
 # on older NVIDIA gpus, you can run training with emulated MX gemm
 # gemm_kernel_choice = MXGemmKernelChoice.EMULATED
 
@@ -97,6 +101,8 @@ on supported hardware, you can run the following command:
 // example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
 ```
 
+On AMD MI355x GPUs with ROCm 7.0 and gfx950, we use HIPBLASLT for the mxfp8 gemm. We are actively working on optimizing the end-to-end performance on AMD hardware.
+
 ## to_mx cast across dim0 and dim1
 
 On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index eb1b15228d..0ebe5702a7 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -32,11 +32,15 @@ class MXGemmKernelChoice(Enum):
     # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
     CUBLAS = "cublas"
 
+    # available only on ROCm with HIPBLASLT support; requires gfx950 and ROCm 7.0
+    HIPBLASLT = "hipblaslt"
+
 
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
+    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
@@ -64,6 +68,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
+    elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
+        assert block_size == 32, (
+            f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
+        )
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert elem_dtype in valid_dtypes, (
+            f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
+        )
+        assert torch.version.hip is not None, "HIPBLASLT requires ROCm"
 
 
 @dataclass
@@ -124,6 +137,8 @@ def from_recipe_name(
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index c7e673dc37..4f460d1d03 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -88,7 +88,11 @@ def _addmm_mx_dispatch(
     """
     gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)
 
-    if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
+    if gemm_choice in (
+        MXGemmKernelChoice.CUBLAS,
+        MXGemmKernelChoice.CUTLASS,
+        MXGemmKernelChoice.HIPBLASLT,
+    ):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert a._data.is_contiguous()
@@ -103,8 +107,11 @@ def _addmm_mx_dispatch(
 
         if a._elem_dtype == torch.float8_e4m3fn:
            assert b._elem_dtype == torch.float8_e4m3fn
-            assert gemm_choice is MXGemmKernelChoice.CUBLAS, (
-                "CUBLAS is the only supported kernel choice for MX FP8 operations"
+            assert gemm_choice in (
+                MXGemmKernelChoice.CUBLAS,
+                MXGemmKernelChoice.HIPBLASLT,
+            ), (
+                "CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations ATM"
             )
 
             res = torch._scaled_mm(
diff --git a/torchao/utils.py b/torchao/utils.py
index 416d23d785..cce9ee5e4f 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -710,3 +710,20 @@ def is_package_at_least(package_name: str, min_version: str):
         return False
 
     return version(package_name) >= min_version
+
+
+def is_ROCm_mx_supported() -> bool:
+    """
+    Check if the current environment supports ROCm MX operations.
+    This requires:
+    1. ROCm platform
+    2. gfx950 GPU (MI350)
+    3. ROCm 7.0
+    """
+    return all(
+        [
+            is_ROCM(),
+            is_MI350(),
+            torch.version.hip is not None and torch.version.hip.startswith("7.0"),
+        ]
+    )
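
Example usage (a minimal sketch, not part of the patch above): it assumes the `quantize_` + `MXLinearConfig` flow already shown in the README, combined with the `MXGemmKernelChoice.HIPBLASLT` option and `is_ROCm_mx_supported` helper added in this diff. The module and tensor shapes are illustrative only, and the string form of `from_recipe_name` is an assumption based on the recipe enum value.

```python
import torch
from torch import nn

from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
from torchao.quantization import quantize_
from torchao.utils import is_ROCm_mx_supported

# Only meaningful on a gfx950 (MI355x) GPU running ROCm 7.0.
assert is_ROCm_mx_supported(), "requires ROCm 7.0 on gfx950"

# Toy model; the inner dims should be divisible by the MX block size of 32.
m = nn.Sequential(nn.Linear(256, 256, bias=False, dtype=torch.bfloat16, device="cuda"))

# Route the mxfp8 gemm through hipBLASLt instead of cuBLAS/CUTLASS.
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT,
)
# Equivalent shortcut via the new recipe name (assuming string recipe names are accepted):
# config = MXLinearConfig.from_recipe_name("mxfp8_hipblaslt")

quantize_(m, config)
y = m(torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"))
```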