From c21d24ce8f2d3a42e02a84a1ffcc0524fc548e8a Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 16 Apr 2025 15:54:23 -0700 Subject: [PATCH 01/15] Enhance MX formats to support HIPBLASLT kernel choice and update validation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm function to accommodate new kernel options. --- torchao/prototype/mx_formats/config.py | 15 +++++++++++++++ torchao/prototype/mx_formats/mx_ops.py | 23 ++++++++++++++++++----- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index e1599cfad5..6d38362899 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -33,12 +33,16 @@ class MXGemmKernelChoice(Enum): # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873 CUBLAS = "cublas" + # available only on ROCm with HIPBLASLT support + HIPBLASLT = "hipblaslt" + # Pre-made recipes for common configurations class MXLinearRecipeName(Enum): MXFP8_EMULATED = "mxfp8_emulated" MXFP8_CUBLAS = "mxfp8_cublas" MXFP8_CUTLASS = "mxfp8_cutlass" + MXFP8_HIPBLASLT = "mxfp8_hipblaslt" MXFP4_EMULATED = "mxfp4_emulated" MXFP4_CUTLASS = "mxfp4_cutlass" @@ -66,6 +70,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): assert ( elem_dtype in valid_dtypes ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" + elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT: + assert ( + block_size == 32 + ), f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}" + valid_dtypes = [torch.float8_e4m3fn] + assert ( + elem_dtype in valid_dtypes + ), f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}" + assert torch.version.hip is not None, "HIPBLASLT requires ROCm" @dataclass @@ -128,6 +141,8 @@ def from_recipe_name( return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS) elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS) + elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT: + return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT) elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED: return MXLinearConfig(elem_dtype=DTYPE_FP4) elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS: diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index c5d60a33de..f08fa9ee28 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -75,8 +75,14 @@ def mx_mm(aten_op, args, kwargs=None): b = args[1] assert isinstance(a, MXTensor) and isinstance(b, MXTensor) assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported" - if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): - # real MX gemm backed by torchao's CUTLASS kernels + kernel_choice = a._gemm_kernel_choice + valid_kernels = ( + MXGemmKernelChoice.CUBLAS, + MXGemmKernelChoice.CUTLASS, + MXGemmKernelChoice.HIPBLASLT, + ) + if kernel_choice in valid_kernels: + # real MX gemm backed by torchao's CUTLASS/CUBLAS/HIPBLASLT kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a._data.is_contiguous() assert b._data.t().is_contiguous() @@ -88,7 +94,12 @@ def mx_mm(aten_op, args, kwargs=None): b_scale_block = to_blocked(b_scale) if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn - if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS: + scaled_mm_kernels = ( + MXGemmKernelChoice.CUBLAS, + MXGemmKernelChoice.HIPBLASLT, + ) + if kernel_choice in scaled_mm_kernels: + # Use native scaled_mm for both CUBLAS and HIPBLASLT res = torch._scaled_mm( a._data, b._data, @@ -103,7 +114,8 @@ def mx_mm(aten_op, args, kwargs=None): else: assert a._elem_dtype == DTYPE_FP4 assert b._elem_dtype == DTYPE_FP4 - assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported" + msg = "FP4 is only supported with CUTLASS kernel at this moment" + assert kernel_choice is MXGemmKernelChoice.CUTLASS, msg res = torchao.ops.mx_fp4_bf16( a._data, b._data, a_scale_block, b_scale_block ) @@ -162,7 +174,8 @@ def mx_view_op(aten_op, args, kwargs=None): if args[0]._elem_dtype == DTYPE_FP4: # special case fp4 as we pack two elements per byte new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) - elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: + elif (args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and + args[0]._pack_fp6): # special case fp6 as we pack 4 elements in 3 bytes new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = aten_op(data, new_size, *args[2:], **kwargs) From 36dd5b7a47884c4a862064dca864ca214db89a7d Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 16 Apr 2025 16:11:41 -0700 Subject: [PATCH 02/15] Update README.md to include support for AMD MI355x hardware and HIPBLASLT kernel choice for mxfp8 gemm. Enhance documentation on end-to-end performance optimization efforts for AMD GPUs. --- torchao/prototype/mx_formats/README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 955a02704f..b9224c2594 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -1,7 +1,7 @@ # MX training and inference with native PyTorch This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) -in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware. +in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 and AMD MI355x hardware. ## Overall status @@ -29,6 +29,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice gemm_kernel_choice = MXGemmKernelChoice.CUBLAS # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS +# on AMD MI355x GPUs with ROCm 6.5+ and gfx950, you can use HIPBLASLT mxfp8 kernels +# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT + # on older NVIDIA gpus, you can run training with emulated MX gemm # gemm_kernel_choice = MXGemmKernelChoice.EMULATED @@ -97,6 +100,8 @@ on supported hardware, you can run the following command: // example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc ``` +On AMD MI355x GPUs with ROCm 6.5+ and gfx950, we use HIPBLASLT for mxfp8 gemm. We are actively working on optimizing the end-to-end performance for AMD hardware. + ## to_mx cast across dim0 and dim1 On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile), From c75df8e07d380673c1c75b300f3b294ef1452b94 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Fri, 18 Apr 2025 09:41:24 -0700 Subject: [PATCH 03/15] lint --- torchao/prototype/mx_formats/mx_ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index f08fa9ee28..cca18f3b89 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -174,8 +174,7 @@ def mx_view_op(aten_op, args, kwargs=None): if args[0]._elem_dtype == DTYPE_FP4: # special case fp4 as we pack two elements per byte new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) - elif (args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and - args[0]._pack_fp6): + elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: # special case fp6 as we pack 4 elements in 3 bytes new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = aten_op(data, new_size, *args[2:], **kwargs) @@ -198,9 +197,9 @@ def autocast_to_copy(aten_op, args, kwargs=None): tensor. """ assert isinstance(args[0], MXTensor) - assert ( - len(kwargs) == 1 and "dtype" in kwargs - ), "Only support dtype kwarg for autocast" + assert len(kwargs) == 1 and "dtype" in kwargs, ( + "Only support dtype kwarg for autocast" + ) assert kwargs["dtype"] in { torch.float16, torch.bfloat16, From df2c2203eae3af02efb4963951284471af4ace0c Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 5 May 2025 14:36:58 -0700 Subject: [PATCH 04/15] lint --- torchao/prototype/mx_formats/config.py | 18 +++++++++--------- torchao/prototype/mx_formats/mx_ops.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3df554e1da..f41aab817a 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -67,17 +67,17 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}" ) valid_dtypes = [torch.float8_e4m3fn] - assert ( - elem_dtype in valid_dtypes - ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" + assert elem_dtype in valid_dtypes, ( + f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" + ) elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT: - assert ( - block_size == 32 - ), f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}" + assert block_size == 32, ( + f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}" + ) valid_dtypes = [torch.float8_e4m3fn] - assert ( - elem_dtype in valid_dtypes - ), f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}" + assert elem_dtype in valid_dtypes, ( + f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}" + ) assert torch.version.hip is not None, "HIPBLASLT requires ROCm" diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 8795364da2..c510dc0c59 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -99,7 +99,7 @@ def mx_mm(aten_op, args, kwargs=None): MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT, ) - + assert a._gemm_kernel_choice is scaled_mm_kernels, ( "CUBLAS/HIPBLASLT is the only supported kernel choice for MX FP8 operations atm" ) From 8ae402148859a14b58dc005df8c2aa13fdf6b1fd Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 4 Jun 2025 14:31:37 -0700 Subject: [PATCH 05/15] Update HIPBLASLT comment in config.py and adjust assertion in mx_ops.py to include HIPBLASLT as a valid kernel choice for MX FP8 operations. --- torchao/prototype/mx_formats/config.py | 2 +- torchao/prototype/mx_formats/mx_ops.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 766ff1556f..6d60edcd38 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -32,7 +32,7 @@ class MXGemmKernelChoice(Enum): # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873 CUBLAS = "cublas" - # available only on ROCm with HIPBLASLT support + # available only on ROCm with HIPBLASLT support, reuqire gfx950 and ROCm 7.0 HIPBLASLT = "hipblaslt" diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index ad1f632d66..b78103178a 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -104,8 +104,7 @@ def _addmm_mx_dispatch( if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn -# TODO : MXGemmKernelChoice.HIPBLASLT check - assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( + assert gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT), ( "CUBLAS is the only supported kernel choice for MX FP8 operations" ) From 8505860243dda7928eac012821026bcd893baf33 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 4 Jun 2025 14:33:33 -0700 Subject: [PATCH 06/15] lint --- torchao/prototype/mx_formats/mx_ops.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index b78103178a..fca6fc25e7 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -63,7 +63,6 @@ def _(func, types, args, kwargs): ) - def _get_gemm_choice( choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] ) -> MXGemmKernelChoice: @@ -89,7 +88,11 @@ def _addmm_mx_dispatch( """ gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) - if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS, MXGemmKernelChoice.HIPBLASLT): + if gemm_choice in ( + MXGemmKernelChoice.CUBLAS, + MXGemmKernelChoice.CUTLASS, + MXGemmKernelChoice.HIPBLASLT, + ): # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a._data.is_contiguous() @@ -104,9 +107,10 @@ def _addmm_mx_dispatch( if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn - assert gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT), ( - "CUBLAS is the only supported kernel choice for MX FP8 operations" - ) + assert gemm_choice in ( + MXGemmKernelChoice.CUBLAS, + MXGemmKernelChoice.HIPBLASLT, + ), "CUBLAS is the only supported kernel choice for MX FP8 operations" res = torch._scaled_mm( a._data, From 129a6d6b5693474ea76f2c88f5e71e55857cd006 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Wed, 4 Jun 2025 14:56:25 -0700 Subject: [PATCH 07/15] lint --- torchao/prototype/mx_formats/README.md | 2 +- torchao/prototype/mx_formats/config.py | 3 --- torchao/prototype/mx_formats/mx_ops.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 95c724d80c..8a38a71365 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -31,7 +31,7 @@ gemm_kernel_choice = MXGemmKernelChoice.CUBLAS # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS # on AMD MI355x GPUs with ROCm 6.5+ and gfx950, you can use HIPBLASLT mxfp8 kernels -# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT +gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT # on older NVIDIA gpus, you can run training with emulated MX gemm # gemm_kernel_choice = MXGemmKernelChoice.EMULATED diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 6d60edcd38..48c97c79a1 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -40,7 +40,6 @@ class MXGemmKernelChoice(Enum): class MXLinearRecipeName(Enum): MXFP8_EMULATED = "mxfp8_emulated" MXFP8_CUBLAS = "mxfp8_cublas" - MXFP8_CUTLASS = "mxfp8_cutlass" MXFP8_HIPBLASLT = "mxfp8_hipblaslt" MXFP4_EMULATED = "mxfp4_emulated" MXFP4_CUTLASS = "mxfp4_cutlass" @@ -138,8 +137,6 @@ def from_recipe_name( return MXLinearConfig() elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS: return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS) - elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: - return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS) elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT: return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT) elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED: diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index fca6fc25e7..fe46ed7367 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -120,7 +120,6 @@ def _addmm_mx_dispatch( bias=bias, out_dtype=torch.bfloat16, ) - else: assert a._elem_dtype == torch.float4_e2m1fn_x2 assert b._elem_dtype == torch.float4_e2m1fn_x2 From c807d707aa0f4558bad8f0570ffc679b07066a82 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 10:56:09 -0700 Subject: [PATCH 08/15] Update assertion message in mx_ops.py to clarify that both CUBLAS and HIPBLASLT are supported kernel choices for MX FP8 operations. --- torchao/prototype/mx_formats/mx_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index fe46ed7367..23432c1348 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -110,7 +110,7 @@ def _addmm_mx_dispatch( assert gemm_choice in ( MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT, - ), "CUBLAS is the only supported kernel choice for MX FP8 operations" + ), "CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations ATM" res = torch._scaled_mm( a._data, From 75db95ec559d557640686cd60bf1fdf12d1a1c9e Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 10:57:31 -0700 Subject: [PATCH 09/15] Refactor assertion in mx_ops.py to improve clarity on supported kernel choices for MX FP8 operations. --- torchao/prototype/mx_formats/mx_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 23432c1348..4f460d1d03 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -110,7 +110,9 @@ def _addmm_mx_dispatch( assert gemm_choice in ( MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT, - ), "CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations ATM" + ), ( + "CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations ATM" + ) res = torch._scaled_mm( a._data, From 3ecc91ef472f62b724730260583f8a20db3299eb Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Mon, 9 Jun 2025 12:04:51 -0700 Subject: [PATCH 10/15] Update torchao/prototype/mx_formats/config.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- torchao/prototype/mx_formats/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 48c97c79a1..0ebe5702a7 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -32,7 +32,7 @@ class MXGemmKernelChoice(Enum): # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873 CUBLAS = "cublas" - # available only on ROCm with HIPBLASLT support, reuqire gfx950 and ROCm 7.0 + # available only on ROCm with HIPBLASLT support, require gfx950 and ROCm 7.0 HIPBLASLT = "hipblaslt" From ef52979dbccdc74a33608033b194114b81054b24 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 13:34:31 -0700 Subject: [PATCH 11/15] Add ROCm MX support check and implement HIPBLASLT FP8 test - Introduced `is_ROCm_mx_supported` function to verify ROCm environment compatibility for MX operations. - Added `test_hipblaslt_fp8` to validate FP8 operations using the HIPBLASLT backend, including SQNR verification for output accuracy. - Updated imports in `test_mx_mm.py` to include necessary utilities for the new test. --- test/prototype/mx_formats/test_mx_mm.py | 38 +++++++++++++++++++++++++ torchao/utils.py | 11 +++++++ 2 files changed, 49 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py index 46380cfb55..0cbb448e09 100644 --- a/test/prototype/mx_formats/test_mx_mm.py +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -10,10 +10,13 @@ from torchao.float8.float8_utils import compute_error from torchao.ops import mx_fp4_bf16 +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import to_blocked +from torchao.testing.utils import compute_sqnr from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, + is_ROCm_mx_supported, is_sm_at_least_100, ) @@ -57,6 +60,41 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float: return compute_error(out_hp, out).item() +@pytest.mark.skipif( + not is_ROCm_mx_supported(), + reason="AMD mxfloat8 test requires ROCm 7.0 on gfx950 GPU", +) +def test_hipblaslt_fp8(): + """Test HIPBLASLT backend for FP8 operations""" + a = torch.randn(128, 128, device="cuda") + b = torch.randn(128, 128, device="cuda") + + a_mx = MXTensor.to_mx( + a, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT + ) + b_mx = MXTensor.to_mx( + b, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT + ) + + # Compute reference result in high precision + out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose( + -1, -2 + ) + + # Compute result using HIPBLASLT backend with scaled_mm + out = torch._scaled_mm( + a_mx._data, + b_mx._data.transpose(-1, -2), + a_mx._scale_e8m0.view(torch.float8_e8m0fnu), + b_mx._scale_e8m0.view(torch.float8_e8m0fnu), + out_dtype=torch.bfloat16, + ) + + # Verify results TODO: ROCm specificthreshold + sqnr = compute_sqnr(out_hp, out) + assert sqnr > 80.0, f"SQNR {sqnr} below threshold 80.0" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" diff --git a/torchao/utils.py b/torchao/utils.py index 416d23d785..a5a92afbac 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -710,3 +710,14 @@ def is_package_at_least(package_name: str, min_version: str): return False return version(package_name) >= min_version + + +def is_ROCm_mx_supported() -> bool: + """ + Check if the current environment supports ROCm MX operations. + This requires: + 1. ROCm platform + 2. gfx950 GPU (MI350) + 3. ROCm 7.0 + """ + return all([is_ROCM(), is_MI350(), torch.version.hip.startswith("7.0")]) From f88f1cf00ead18053e974929091eb372ddab981a Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 13:40:38 -0700 Subject: [PATCH 12/15] add space --- test/prototype/mx_formats/test_mx_mm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py index 0cbb448e09..5e0b008cb3 100644 --- a/test/prototype/mx_formats/test_mx_mm.py +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -90,7 +90,7 @@ def test_hipblaslt_fp8(): out_dtype=torch.bfloat16, ) - # Verify results TODO: ROCm specificthreshold + # Verify results TODO: ROCm specific threshold sqnr = compute_sqnr(out_hp, out) assert sqnr > 80.0, f"SQNR {sqnr} below threshold 80.0" From 5d2b55db3dcc75c71ce16d37c3d18086a4b10ebe Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 13:49:58 -0700 Subject: [PATCH 13/15] Refactor SQNR calculation in HIPBLASLT FP8 test - Replaced `compute_sqnr` with `compute_error` for improved accuracy in error measurement. - Updated assertion to ensure output accuracy meets the specified threshold. --- test/prototype/mx_formats/test_mx_mm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py index 5e0b008cb3..a7600c12ea 100644 --- a/test/prototype/mx_formats/test_mx_mm.py +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -13,7 +13,6 @@ from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import to_blocked -from torchao.testing.utils import compute_sqnr from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, is_ROCm_mx_supported, @@ -91,7 +90,7 @@ def test_hipblaslt_fp8(): ) # Verify results TODO: ROCm specific threshold - sqnr = compute_sqnr(out_hp, out) + sqnr = compute_error(out_hp, out).item() assert sqnr > 80.0, f"SQNR {sqnr} below threshold 80.0" From 979893a9d26db7c1ff0924d9bbd58ac0588d1a73 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 15:42:37 -0700 Subject: [PATCH 14/15] Enhance ROCm MX support check in `is_ROCm_mx_supported` function - Updated the function to ensure `torch.version.hip` is not None before checking the version, improving robustness against potential NoneType errors. --- torchao/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchao/utils.py b/torchao/utils.py index a5a92afbac..82d208a734 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -720,4 +720,8 @@ def is_ROCm_mx_supported() -> bool: 2. gfx950 GPU (MI350) 3. ROCm 7.0 """ - return all([is_ROCM(), is_MI350(), torch.version.hip.startswith("7.0")]) + return all([ + is_ROCM(), + is_MI350(), + torch.version.hip is not None and torch.version.hip.startswith("7.0") + ]) From 012f938bb66d0c7367f129156d1aa6f5543e952e Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 9 Jun 2025 15:43:55 -0700 Subject: [PATCH 15/15] Refactor `is_ROCm_mx_supported` function for improved readability - Reformatted the return statement to enhance clarity and maintainability of the code. --- torchao/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchao/utils.py b/torchao/utils.py index 82d208a734..cce9ee5e4f 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -720,8 +720,10 @@ def is_ROCm_mx_supported() -> bool: 2. gfx950 GPU (MI350) 3. ROCm 7.0 """ - return all([ - is_ROCM(), - is_MI350(), - torch.version.hip is not None and torch.version.hip.startswith("7.0") - ]) + return all( + [ + is_ROCM(), + is_MI350(), + torch.version.hip is not None and torch.version.hip.startswith("7.0"), + ] + )