From 3a7764119ba6dc4c60ed013e3f24c2e0b69eedb3 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 20:49:08 -0700 Subject: [PATCH 01/23] Fix ROCm GPU architecture detection in setup.py Update GPU architecture check to use gcnArchName and improve detection of gfx942 support --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 1c708741a4..ead3877ae4 100644 --- a/setup.py +++ b/setup.py @@ -363,8 +363,8 @@ def get_extensions(): # TOOD: Remove this and use what CUDA has once we fix all the builds. if IS_ROCM and use_cuda: # Add ROCm GPU architecture check - gpu_arch = torch.cuda.get_device_properties(0).name - if gpu_arch != "gfx942": + gpu_arch = torch.cuda.get_device_properties(0).gcnArchName + if "gfx942" not in gpu_arch: print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") print( "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" From 76d68bf64c42d7021310f997e8c0c99ffc0be14c Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 20:54:53 -0700 Subject: [PATCH 02/23] Refactor CUDA and ROCm source file handling in setup.py Reorganize source file selection logic for CUDA and ROCm builds, improving conditional handling of GPU sources and CUTLASS kernels. Simplify the source file selection process and improve readability of the build configuration. --- setup.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/setup.py b/setup.py index ead3877ae4..0bbfb42ea6 100644 --- a/setup.py +++ b/setup.py @@ -91,9 +91,9 @@ def __init__(self): default=(self._is_arm64() and self._is_macos()), ) if self.build_cpu_aarch64: - assert ( - self._is_arm64() - ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" + assert self._is_arm64(), ( + "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" + ) # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because # 1) It increases the build time @@ -102,9 +102,9 @@ def __init__(self): "TORCHAO_BUILD_KLEIDIAI", default=False ) if self.build_kleidi_ai: - assert ( - self.build_cpu_aarch64 - ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" + assert self.build_cpu_aarch64, ( + "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" + ) # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default. self.build_experimental_mps = self._os_bool_var( @@ -113,9 +113,9 @@ def __init__(self): if self.build_experimental_mps: assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS" assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64" - assert ( - torch.mps.is_available() - ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" + assert torch.mps.is_available(), ( + "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" + ) def _is_arm64(self) -> bool: return platform.machine().startswith("arm64") @@ -341,6 +341,7 @@ def get_extensions(): hip_sources = list( glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) ) + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin") hip_sources += list( glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) @@ -349,6 +350,16 @@ def get_extensions(): # Collect CUDA source files if needed if not IS_ROCM and use_cuda: sources += cuda_sources + elif IS_ROCM and use_cuda: + # Add ROCm GPU architecture check + gpu_arch = torch.cuda.get_device_properties(0).gcnArchName + if "gfx942" not in gpu_arch: + print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") + print( + "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" + ) + else: + sources += hip_sources else: # Remove CUTLASS-based kernels from the cuda_sources list. An # assumption is that these files will have "cutlass" in its @@ -360,18 +371,6 @@ def get_extensions(): ) sources = [s for s in sources if s not in cutlass_sources] - # TOOD: Remove this and use what CUDA has once we fix all the builds. - if IS_ROCM and use_cuda: - # Add ROCm GPU architecture check - gpu_arch = torch.cuda.get_device_properties(0).gcnArchName - if "gfx942" not in gpu_arch: - print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") - print( - "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" - ) - else: - sources += hip_sources - ext_modules = [] if len(sources) > 0: ext_modules.append( From 16d22c13d6febbd0509e838bd675dbbd6d6dad2e Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 20:58:47 -0700 Subject: [PATCH 03/23] Improve CUTLASS kernel support detection for non-Windows platforms Modify CUTLASS kernel configuration to explicitly check for non-ROCm platforms when enabling support, ensuring more precise build configuration for different GPU environments. --- setup.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 0bbfb42ea6..8bd667077e 100644 --- a/setup.py +++ b/setup.py @@ -292,20 +292,8 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") - curdir = os.path.dirname(os.path.curdir) - extensions_dir = os.path.join(curdir, "torchao", "csrc") - sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) - - extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list( - glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) - ) - - if use_cuda: - sources += cuda_sources - use_cutlass = False - if use_cuda and not IS_WINDOWS: + if use_cuda and not IS_ROCM and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") From 7481959d17501cf92bbc93fe68a2650f7b50cc98 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 21:03:38 -0700 Subject: [PATCH 04/23] Reorder source file collection in setup.py Move source file collection logic to maintain consistent code organization and improve readability of the build configuration. No functional changes were made to the source file selection process. --- setup.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/setup.py b/setup.py index 8bd667077e..4ef715f0b1 100644 --- a/setup.py +++ b/setup.py @@ -292,25 +292,6 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") - use_cutlass = False - if use_cuda and not IS_ROCM and not IS_WINDOWS: - use_cutlass = True - cutlass_dir = os.path.join(third_party_path, "cutlass") - cutlass_include_dir = os.path.join(cutlass_dir, "include") - cutlass_tools_include_dir = os.path.join( - cutlass_dir, "tools", "util", "include" - ) - cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir) - if use_cutlass: - extra_compile_args["nvcc"].extend( - [ - "-DTORCHAO_USE_CUTLASS", - "-I" + cutlass_include_dir, - "-I" + cutlass_tools_include_dir, - "-I" + cutlass_extensions_include_dir, - ] - ) - # Get base directory and source paths curdir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(curdir, "torchao", "csrc") @@ -335,6 +316,25 @@ def get_extensions(): glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) ) + use_cutlass = False + if use_cuda and not IS_ROCM and not IS_WINDOWS: + use_cutlass = True + cutlass_dir = os.path.join(third_party_path, "cutlass") + cutlass_include_dir = os.path.join(cutlass_dir, "include") + cutlass_tools_include_dir = os.path.join( + cutlass_dir, "tools", "util", "include" + ) + cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir) + if use_cutlass: + extra_compile_args["nvcc"].extend( + [ + "-DTORCHAO_USE_CUTLASS", + "-I" + cutlass_include_dir, + "-I" + cutlass_tools_include_dir, + "-I" + cutlass_extensions_include_dir, + ] + ) + # Collect CUDA source files if needed if not IS_ROCM and use_cuda: sources += cuda_sources From 94d1fb467d2ac7309e1c375799a915d1d2f5790d Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 21:11:15 -0700 Subject: [PATCH 05/23] Remove redundant NVCC compilation flag in setup.py Remove the `-t=0` flag from NVCC compilation options, which appears to be unnecessary. This simplifies the compilation configuration without impacting build behavior. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4ef715f0b1..f79d731209 100644 --- a/setup.py +++ b/setup.py @@ -269,7 +269,7 @@ def get_extensions(): extra_link_args = [] extra_compile_args = { "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], - "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"], + "nvcc": ["-O3" if not debug_mode else "-O0", "-std=c++17"], } if not IS_WINDOWS: From 72c2642a6636fd3d61c10c6de50eeaa5c07665d3 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 22:36:46 -0700 Subject: [PATCH 06/23] Add ROCm-specific inline assembly for sparse Marlin MMA operations Add conditional compilation for ROCm platforms in the sparse Marlin matrix multiply accumulate (MMA) function. This ensures proper inline assembly implementation for both CUDA and ROCm environments, using platform-specific register and instruction handling. --- torchao/csrc/cuda/sparse_marlin/mma.h | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 9e9a9be519..d1918f1898 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -53,6 +53,22 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { + #ifdef USE_ROCM + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]), + "v"(c[2]), "v"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]), + "v"(c[6]), "v"(c[7]), "r"(e[0])); + #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" @@ -67,7 +83,24 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), "r"(e[0])); + #endif } else { + #ifdef USE_ROCM + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]), + "v"(c[2]), "v"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]), + "v"(c[6]), "v"(c[7]), "r"(e[0])); + #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" @@ -82,6 +115,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), "r"(e[0])); + #endif } } From 75f47874bc29d5c3197255f343402b0319abd480 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 22:41:08 -0700 Subject: [PATCH 07/23] Fix ROCm half-precision conversion in sparse Marlin MMA Use __builtin_bit_cast to correctly convert float pairs to half-precision uint32_t values for AMD GPU platforms, ensuring proper type handling in the sparse Marlin matrix multiply accumulate (MMA) implementation. --- torchao/csrc/cuda/sparse_marlin/mma.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index d1918f1898..6e7f0af09a 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -27,8 +27,8 @@ namespace torchao { // On CUDA earlier than 12.5, the ordered_metadata version of this instruction // is not supported. On later versions of CUDA the version without ordered // metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially +// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction +// | 'mma' instead of modifier 'sp' as it is expected to have substantially // | reduced performance on some future architectures #if defined(USE_ROCM) @@ -143,8 +143,8 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, uint2 r; #ifdef USE_ROCM // AMD implementation - r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1); - r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3); + r.x = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c0, c1)); + r.y = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c2, c3)); #else // NVIDIA implementation asm("{\n\t" From cf7903976596524692001ce0917b610a6166ffb8 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 22:46:02 -0700 Subject: [PATCH 08/23] Optimize half-precision operations in sparse Marlin MMA Update CUDA half-precision operations using __hsub2 and __hfma2 intrinsics to improve performance and precision in sparse matrix multiply-accumulate (MMA) computations. --- torchao/csrc/cuda/sparse_marlin/mma.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 6e7f0af09a..de84ebdec2 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -206,8 +206,8 @@ __device__ inline FragB dequant_4bit(int q) { const __half2* MUL_ptr = reinterpret_cast(&MUL); const __half2* ADD_ptr = reinterpret_cast(&ADD); - frag_b[0] = __hsub(*lo_ptr, *SUB_ptr); - frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr); + frag_b[0] = __hsub2(*lo_ptr, *SUB_ptr); + frag_b[1] = __hfma2(*hi_ptr, *MUL_ptr, *ADD_ptr); #else // NVIDIA implementation frag_b[0] = __hsub2(*reinterpret_cast(&lo), From a98a427e3cbf9d8206aba1daea6149d2b31131d4 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 22:48:01 -0700 Subject: [PATCH 09/23] Optimize ROCm half-precision operations in sparse Marlin MMA Update AMD GPU implementation to use __hsub2 and __hmul2 intrinsics for improved performance and precision in half-precision sparse matrix multiply-accumulate computations. --- torchao/csrc/cuda/sparse_marlin/mma.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index de84ebdec2..6da6a71ae9 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -240,8 +240,8 @@ __device__ inline FragB dequant_8bit(int q) { __half2* hi_ptr = reinterpret_cast<__half2*>(&hi); const __half2* magic_num_ptr = reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM); - frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr); - frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr); + frag_b[0] = __hsub2(*lo_ptr, *magic_num_ptr); + frag_b[1] = __hsub2(*hi_ptr, *magic_num_ptr); #else // NVIDIA implementation frag_b[0] = __hsub2(*reinterpret_cast(&lo), @@ -258,8 +258,8 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { #ifdef USE_ROCM // AMD implementation __half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul(frag_b[0], s); - frag_b[1] = __hmul(frag_b[1], s); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); #else // NVIDIA implementation half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); From 30bd92480345dad04f60e45d2abf1460cf21219a Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 22:50:55 -0700 Subject: [PATCH 10/23] Fix ROCm float multiplication in sparse Marlin MMA Update AMD GPU implementation to use __builtin_amdgcn_fmul_f32 instead of __builtin_amdgcn_fmul_legacy for more accurate float multiplication in the scale_floats function. --- torchao/csrc/cuda/sparse_marlin/mma.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 6da6a71ae9..dbdecb4132 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -272,16 +272,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, FragS& s0, float* c4, float* c5, float* c6, float* c7, FragS& s1) { #ifdef USE_ROCM - // AMD implementation - *c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x)); - *c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y)); - *c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x)); - *c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y)); + // AMD implementation - fixed + *c0 = __builtin_amdgcn_fmul_f32(*c0, __half2float(s0[0].x)); + *c1 = __builtin_amdgcn_fmul_f32(*c1, __half2float(s0[0].y)); + *c2 = __builtin_amdgcn_fmul_f32(*c2, __half2float(s0[1].x)); + *c3 = __builtin_amdgcn_fmul_f32(*c3, __half2float(s0[1].y)); - *c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x)); - *c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y)); - *c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x)); - *c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y)); + *c4 = __builtin_amdgcn_fmul_f32(*c4, __half2float(s1[0].x)); + *c5 = __builtin_amdgcn_fmul_f32(*c5, __half2float(s1[0].y)); + *c6 = __builtin_amdgcn_fmul_f32(*c6, __half2float(s1[1].x)); + *c7 = __builtin_amdgcn_fmul_f32(*c7, __half2float(s1[1].y)); #else // NVIDIA implementation *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); From 66691c32bc1bafeaa6e774f9131a2e9caffbcd72 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 23:03:43 -0700 Subject: [PATCH 11/23] Add ROCm header support for sparse Marlin MMA implementation Include necessary ROCm-specific headers for HIP runtime and half-precision operations, with comments addressing potential compiler and architecture considerations for AMD GPU platforms. --- torchao/csrc/cuda/sparse_marlin/mma.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index dbdecb4132..d7ded85550 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -22,13 +22,21 @@ #include #endif +#ifdef USE_ROCM +#include +#include +#include // For some ROCm versions +// Some intrinsics might require the compiler to be in the right mode +// with the correct target architecture flags (-march=gfx942) +#endif + namespace torchao { // On CUDA earlier than 12.5, the ordered_metadata version of this instruction // is not supported. On later versions of CUDA the version without ordered // metadata results in the following warning: -// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction -// | 'mma' instead of modifier 'sp' as it is expected to have substantially +// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction +// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially // | reduced performance on some future architectures #if defined(USE_ROCM) From 04014e70869e6490281289cb7220c079d3f27505 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Mon, 10 Mar 2025 23:51:00 -0700 Subject: [PATCH 12/23] Update ROCm float multiplication in sparse Marlin MMA Replace __builtin_amdgcn_fmul_f32 with __ocml_fmul_f32 for more accurate and consistent float multiplication in the scale_floats function on AMD GPU platforms. --- torchao/csrc/cuda/sparse_marlin/mma.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index d7ded85550..0d4de12993 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -35,8 +35,8 @@ namespace torchao { // On CUDA earlier than 12.5, the ordered_metadata version of this instruction // is not supported. On later versions of CUDA the version without ordered // metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially +// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction +// | 'mma' instead of modifier 'sp' as it is expected to have substantially // | reduced performance on some future architectures #if defined(USE_ROCM) @@ -281,15 +281,15 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, float* c7, FragS& s1) { #ifdef USE_ROCM // AMD implementation - fixed - *c0 = __builtin_amdgcn_fmul_f32(*c0, __half2float(s0[0].x)); - *c1 = __builtin_amdgcn_fmul_f32(*c1, __half2float(s0[0].y)); - *c2 = __builtin_amdgcn_fmul_f32(*c2, __half2float(s0[1].x)); - *c3 = __builtin_amdgcn_fmul_f32(*c3, __half2float(s0[1].y)); + *c0 = __ocml_fmul_f32(*c0, __half2float(s0[0].x)); + *c1 = __ocml_fmul_f32(*c1, __half2float(s0[0].y)); + *c2 = __ocml_fmul_f32(*c2, __half2float(s0[1].x)); + *c3 = __ocml_fmul_f32(*c3, __half2float(s0[1].y)); - *c4 = __builtin_amdgcn_fmul_f32(*c4, __half2float(s1[0].x)); - *c5 = __builtin_amdgcn_fmul_f32(*c5, __half2float(s1[0].y)); - *c6 = __builtin_amdgcn_fmul_f32(*c6, __half2float(s1[1].x)); - *c7 = __builtin_amdgcn_fmul_f32(*c7, __half2float(s1[1].y)); + *c4 = __ocml_fmul_f32(*c4, __half2float(s1[0].x)); + *c5 = __ocml_fmul_f32(*c5, __half2float(s1[0].y)); + *c6 = __ocml_fmul_f32(*c6, __half2float(s1[1].x)); + *c7 = __ocml_fmul_f32(*c7, __half2float(s1[1].y)); #else // NVIDIA implementation *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); From dc539804ea1079c5483a24ae750910812e3e47e5 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:01:36 -0700 Subject: [PATCH 13/23] Optimize ROCm global to LDS transfer in sparse Marlin MMA Replace __builtin_amdgcn_global_load_lds with inline assembly using ds_load_b instruction for more precise and direct global to local data store (LDS) transfer on MI300X AMD GPUs. --- torchao/csrc/cuda/sparse_marlin/mem.h | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 1569e3cdda..3a25180f8f 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -51,7 +51,11 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); + // Use LDS.G instruction for global to LDS transfer on MI300X + asm volatile( + "{\n" + " ds_load_b%c2 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); #else asm volatile( "{\n" @@ -68,7 +72,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); + // Use LDS.G instruction for global to LDS transfer on MI300X + asm volatile( + "{\n" + " ds_load_b%c2 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); #else asm volatile( "{\n" @@ -85,7 +93,11 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); + // Use LDS.G instruction for global to LDS transfer on MI300X + asm volatile( + "{\n" + " ds_load_b%c2 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); #else asm volatile( "{\n" From 6f43e014371f4a385cbc23a98843d443c4161a88 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:07:13 -0700 Subject: [PATCH 14/23] Simplify ROCm float multiplication in sparse Marlin MMA Replace __ocml_fmul_f32 with standard C++ multiplication for more readable and straightforward float scaling on AMD MI300X GPUs. --- torchao/csrc/cuda/sparse_marlin/mma.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 0d4de12993..8894b3dcce 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -280,16 +280,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, FragS& s0, float* c4, float* c5, float* c6, float* c7, FragS& s1) { #ifdef USE_ROCM - // AMD implementation - fixed - *c0 = __ocml_fmul_f32(*c0, __half2float(s0[0].x)); - *c1 = __ocml_fmul_f32(*c1, __half2float(s0[0].y)); - *c2 = __ocml_fmul_f32(*c2, __half2float(s0[1].x)); - *c3 = __ocml_fmul_f32(*c3, __half2float(s0[1].y)); +// AMD MI300X implementation + *c0 = *c0 * __half2float(s0[0].x); + *c1 = *c1 * __half2float(s0[0].y); + *c2 = *c2 * __half2float(s0[1].x); + *c3 = *c3 * __half2float(s0[1].y); - *c4 = __ocml_fmul_f32(*c4, __half2float(s1[0].x)); - *c5 = __ocml_fmul_f32(*c5, __half2float(s1[0].y)); - *c6 = __ocml_fmul_f32(*c6, __half2float(s1[1].x)); - *c7 = __ocml_fmul_f32(*c7, __half2float(s1[1].y)); + *c4 = *c4 * __half2float(s1[0].x); + *c5 = *c5 * __half2float(s1[0].y); + *c6 = *c6 * __half2float(s1[1].x); + *c7 = *c7 * __half2float(s1[1].y); #else // NVIDIA implementation *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); From ed9282dbddc2091de1bae57adf2447cea302bff6 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:24:22 -0700 Subject: [PATCH 15/23] Fix CUDA kernel attribute setting in Marlin sparse MMA implementation Update cudaFuncSetAttribute call to use reinterpret_cast for correct function pointer handling in the Marlin_24 CUDA kernel, ensuring proper dynamic shared memory configuration. --- torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu index 4f6980f29a..c583011a07 100644 --- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu +++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu @@ -862,8 +862,8 @@ __global__ void Marlin_24( thread_k_blocks == THREAD_K_BLOCKS && \ group_blocks == GROUP_BLOCKS) { \ cudaFuncSetAttribute( \ - Marlin_24, \ + reinterpret_cast(&Marlin_24), \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin_24 \ From b5390621568221779ce38c9993ac3d6ad004f2b8 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:30:39 -0700 Subject: [PATCH 16/23] Enhance ROCm global to LDS transfer with size-specific load instructions Refactor cp_async4 functions for ROCm to use explicit ds_load instructions for 4, 8, and 16-byte transfers. Add a fallback mechanism using __builtin_memcpy for unsupported sizes, improving the precision and flexibility of global to local data store (LDS) transfers on MI300X AMD GPUs. --- torchao/csrc/cuda/sparse_marlin/mem.h | 72 +++++++++++++++++++++------ 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 3a25180f8f..b24293c14f 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -51,11 +51,25 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - // Use LDS.G instruction for global to LDS transfer on MI300X - asm volatile( - "{\n" - " ds_load_b%c2 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); + // Use appropriate ds_load instruction based on byte size + if (BYTES == 4) { + asm volatile( + "{\n" + " ds_load_b32 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 8) { + asm volatile( + "{\n" + " ds_load_b64 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 16) { + asm volatile( + "{\n" + " ds_load_b128 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else { + // Fallback for other sizes + __builtin_memcpy(smem, glob_ptr, BYTES); #else asm volatile( "{\n" @@ -72,11 +86,25 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - // Use LDS.G instruction for global to LDS transfer on MI300X - asm volatile( - "{\n" - " ds_load_b%c2 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); + // Use appropriate ds_load instruction based on byte size + if (BYTES == 4) { + asm volatile( + "{\n" + " ds_load_b32 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 8) { + asm volatile( + "{\n" + " ds_load_b64 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 16) { + asm volatile( + "{\n" + " ds_load_b128 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else { + // Fallback for other sizes + __builtin_memcpy(smem, glob_ptr, BYTES); #else asm volatile( "{\n" @@ -93,11 +121,25 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - // Use LDS.G instruction for global to LDS transfer on MI300X - asm volatile( - "{\n" - " ds_load_b%c2 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr), "i"(BYTES)); + // Use appropriate ds_load instruction based on byte size + if (BYTES == 4) { + asm volatile( + "{\n" + " ds_load_b32 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 8) { + asm volatile( + "{\n" + " ds_load_b64 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else if (BYTES == 16) { + asm volatile( + "{\n" + " ds_load_b128 %0, %1\n" + "}\n" :: "v"(smem), "v"(glob_ptr)); + } else { + // Fallback for other sizes + __builtin_memcpy(smem, glob_ptr, BYTES); #else asm volatile( "{\n" From c316a98030fe95d01f97e969e2c9f52a0852a3af Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:39:48 -0700 Subject: [PATCH 17/23] Fix missing closing braces in ROCm cp_async4 memory transfer functions Add missing closing braces in cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions to ensure proper code structure and prevent potential compilation issues in the ROCm sparse Marlin MMA implementation. --- torchao/csrc/cuda/sparse_marlin/mem.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index b24293c14f..ddffe3eb52 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -70,6 +70,7 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, } else { // Fallback for other sizes __builtin_memcpy(smem, glob_ptr, BYTES); + } #else asm volatile( "{\n" @@ -105,6 +106,7 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, } else { // Fallback for other sizes __builtin_memcpy(smem, glob_ptr, BYTES); + } #else asm volatile( "{\n" @@ -140,6 +142,7 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { } else { // Fallback for other sizes __builtin_memcpy(smem, glob_ptr, BYTES); + } #else asm volatile( "{\n" From 3a2481fe449e1f1722abe881987d192b4b9c997b Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:43:36 -0700 Subject: [PATCH 18/23] Remove unnecessary fallback memcpy in ROCm cp_async4 memory transfer functions Simplify ROCm global to LDS transfer by removing fallback __builtin_memcpy in cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, reducing code complexity while maintaining the primary ds_load_b128 transfer mechanism. --- torchao/csrc/cuda/sparse_marlin/mem.h | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index ddffe3eb52..c33e7453f6 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -67,10 +67,7 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, "{\n" " ds_load_b128 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else { - // Fallback for other sizes - __builtin_memcpy(smem, glob_ptr, BYTES); - } + } #else asm volatile( "{\n" @@ -103,9 +100,6 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, "{\n" " ds_load_b128 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else { - // Fallback for other sizes - __builtin_memcpy(smem, glob_ptr, BYTES); } #else asm volatile( @@ -139,10 +133,7 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { "{\n" " ds_load_b128 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else { - // Fallback for other sizes - __builtin_memcpy(smem, glob_ptr, BYTES); - } + } #else asm volatile( "{\n" From 59455ed15af4c671d554f8b9931f6cff13d252ae Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:51:18 -0700 Subject: [PATCH 19/23] Remove 16-byte ds_load instruction in ROCm cp_async4 memory transfer functions Simplify ROCm global to LDS transfer by removing the 16-byte ds_load_b128 instruction from cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, further reducing code complexity and maintaining the core transfer mechanism. --- torchao/csrc/cuda/sparse_marlin/mem.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index c33e7453f6..f3434cd976 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -62,11 +62,6 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr, "{\n" " ds_load_b64 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else if (BYTES == 16) { - asm volatile( - "{\n" - " ds_load_b128 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr)); } #else asm volatile( @@ -95,11 +90,6 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, "{\n" " ds_load_b64 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else if (BYTES == 16) { - asm volatile( - "{\n" - " ds_load_b128 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr)); } #else asm volatile( @@ -128,11 +118,6 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { "{\n" " ds_load_b64 %0, %1\n" "}\n" :: "v"(smem), "v"(glob_ptr)); - } else if (BYTES == 16) { - asm volatile( - "{\n" - " ds_load_b128 %0, %1\n" - "}\n" :: "v"(smem), "v"(glob_ptr)); } #else asm volatile( From ca6c6463b8c8b33fbe378264a8a9f9ee2dd6af3e Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 14:57:30 -0700 Subject: [PATCH 20/23] global_load_dwordx4 --- torchao/csrc/cuda/sparse_marlin/mem.h | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index f3434cd976..57fba53125 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -156,11 +156,12 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - asm volatile( - "ds_read_b128 %0, %1 offset:0\n" - "ds_read_b128 %2, %1 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) - : "v"(smem)); + // MI300 specific implementation - try global_load if available + asm volatile( + "global_load_dwordx4 %0, %4\n" + "global_load_dwordx4 %2, %4 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); #else asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -189,11 +190,12 @@ __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - asm volatile( - "ds_read_b128 %0, %1 offset:0\n" - "ds_read_b128 %2, %1 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) - : "v"(smem)); + // MI300 specific implementation - try global_load if available + asm volatile( + "global_load_dwordx4 %0, %4\n" + "global_load_dwordx4 %2, %4 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); #else asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" From d1a9df95af7b13e2d7aa3518f3ac31de8d97075e Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 15:03:21 -0700 Subject: [PATCH 21/23] Improve ROCm memory load instructions in sparse Marlin MMA implementation Replace global_load_dwordx4 with multiple ds_read_b32 instructions for better compatibility and support across different ROCm platforms. Modify ldsm4 and ldsm4_t functions to use more widely supported memory load techniques. --- torchao/csrc/cuda/sparse_marlin/mem.h | 30 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 57fba53125..20ba3b5a92 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -156,11 +156,18 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - // MI300 specific implementation - try global_load if available + // Try using multiple ds_read_b32 instructions which are more widely supported asm volatile( - "global_load_dwordx4 %0, %4\n" - "global_load_dwordx4 %2, %4 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + "ds_read_b32 %0, %8 offset:0\n" + "ds_read_b32 %1, %8 offset:4\n" + "ds_read_b32 %2, %8 offset:8\n" + "ds_read_b32 %3, %8 offset:12\n" + "ds_read_b32 %4, %8 offset:16\n" + "ds_read_b32 %5, %8 offset:20\n" + "ds_read_b32 %6, %8 offset:24\n" + "ds_read_b32 %7, %8 offset:28\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]), + "=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7]) : "v"(smem)); #else asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -190,11 +197,18 @@ __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM - // MI300 specific implementation - try global_load if available + // Try using multiple ds_read_b32 instructions which are more widely supported asm volatile( - "global_load_dwordx4 %0, %4\n" - "global_load_dwordx4 %2, %4 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + "ds_read_b32 %0, %8 offset:0\n" + "ds_read_b32 %1, %8 offset:4\n" + "ds_read_b32 %2, %8 offset:8\n" + "ds_read_b32 %3, %8 offset:12\n" + "ds_read_b32 %4, %8 offset:16\n" + "ds_read_b32 %5, %8 offset:20\n" + "ds_read_b32 %6, %8 offset:24\n" + "ds_read_b32 %7, %8 offset:28\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]), + "=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7]) : "v"(smem)); #else asm volatile( From 63e8d5ee0ab40920d922aaadcd32dfa84ad97091 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 15:26:23 -0700 Subject: [PATCH 22/23] Refine ROCm memory load instruction in sparse Marlin ldsm4_m function Update ldsm4_m device function to use separate ds_read_b32 instructions instead of a single ds_read_b64, improving compatibility and load behavior on ROCm platforms. --- torchao/csrc/cuda/sparse_marlin/mem.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 20ba3b5a92..b9f3452a82 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -181,7 +181,8 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { uint32_t smem = cvta_to_shared(smem_ptr); #ifdef USE_ROCM asm volatile( - "ds_read_b64 %0, %2 offset:0\n" + "ds_read_b32 %0, %2 offset:0\n" + "ds_read_b32 %1, %2 offset:4\n" : "=v"(a[0]), "=v"(a[1]) : "v"(smem)); #else From 3e5a411abc4af93cf5d7974c78b24c33147bd826 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Tue, 11 Mar 2025 15:34:38 -0700 Subject: [PATCH 23/23] Update ROCm MFMA instruction syntax in sparse Marlin MMA implementation Modify the MFMA instruction assembly for AMD GPUs to use correct syntax and operand handling. Replace register constraints with vector register constraints and simplify the instruction format to improve compatibility and readability on ROCm platforms. --- torchao/csrc/cuda/sparse_marlin/mma.h | 43 +++++++++++++-------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 8894b3dcce..8ba3d22669 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -40,8 +40,8 @@ namespace torchao { // | reduced performance on some future architectures #if defined(USE_ROCM) - // HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the standard mma instruction - #define MMA_SP_INST "v_mfma_f32_16x16x16f16 " + // Correct MFMA instruction for AMD GPUs + #define MMA_SP_INST "v_mfma_f32_16x16x16_f16 " #elif defined(CUDA_VERSION) && CUDA_VERSION >= 12050 #define MMA_SP_INST \ "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " @@ -62,20 +62,21 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { #ifdef USE_ROCM + // AMD GPUs use a different syntax for MFMA instructions + // The operands need to be listed individually, not in curly braces asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" + "%0, %4, %8, %12\n" : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]), - "v"(c[2]), "v"(c[3]), "r"(e[0])); + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]), + "v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3])); + asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" + "%0, %4, %8, %12\n" : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]), - "v"(c[6]), "v"(c[7]), "r"(e[0])); + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]), + "v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7])); #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " @@ -95,19 +96,17 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, } else { #ifdef USE_ROCM asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" + "%0, %4, %8, %12\n" : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]), - "v"(c[2]), "v"(c[3]), "r"(e[0])); + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]), + "v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3])); asm volatile(MMA_SP_INST - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" + "%0, %4, %8, %12\n" : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]), - "v"(c[6]), "v"(c[7]), "r"(e[0])); + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]), + "v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7])); #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "