diff --git a/setup.py b/setup.py
index 0b5de7e855..075e241e41 100644
--- a/setup.py
+++ b/setup.py
@@ -97,9 +97,9 @@ def __init__(self):
             default=(self._is_arm64() and self._is_macos()),
         )
         if self.build_cpu_aarch64:
-            assert (
-                self._is_arm64()
-            ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
+            assert self._is_arm64(), (
+                "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
+            )
 
         # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
         # 1) It increases the build time
@@ -108,9 +108,9 @@ def __init__(self):
             "TORCHAO_BUILD_KLEIDIAI", default=False
         )
         if self.build_kleidi_ai:
-            assert (
-                self.build_cpu_aarch64
-            ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
+            assert self.build_cpu_aarch64, (
+                "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
+            )
 
         # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
         self.build_experimental_mps = self._os_bool_var(
@@ -119,9 +119,9 @@ def __init__(self):
         if self.build_experimental_mps:
             assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
             assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
-            assert (
-                torch.mps.is_available()
-            ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
+            assert torch.mps.is_available(), (
+                "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
+            )
 
     def _is_arm64(self) -> bool:
         return platform.machine().startswith("arm64")
@@ -338,6 +338,7 @@ def get_extensions():
         hip_sources = list(
             glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
         )
+        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
         hip_sources += list(
             glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
         )
@@ -350,7 +351,7 @@ def get_extensions():
     # TOOD: Remove this and use what CUDA has once we fix all the builds.
     if IS_ROCM and use_cuda:
         # Add ROCm GPU architecture check
-        gpu_arch = torch.cuda.get_device_properties(0).name
+        gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
         if gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
             print(
diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
index bd64930c4b..8e3cc1fff1 100644
--- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
+++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -867,8 +867,8 @@ __global__ void Marlin_24(
              thread_k_blocks == THREAD_K_BLOCKS &&                        \
              group_blocks == GROUP_BLOCKS) {                              \
     cudaFuncSetAttribute(                                                 \
-        Marlin_24,                                                        \
+        reinterpret_cast<const void*>(&Marlin_24),                        \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);     \
     Marlin_24                                                             \
diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h
index 3043ec6435..d2bb105578 100644
--- a/torchao/csrc/cuda/sparse_marlin/mem.h
+++ b/torchao/csrc/cuda/sparse_marlin/mem.h
@@ -56,7 +56,18 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
   int src_in_bytes = (zfill ? 0 : BYTES);
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
+  // Use appropriate ds_load instruction based on byte size
+  if (BYTES == 4) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b32 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  } else if (BYTES == 8) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b64 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  }
 #else
   asm volatile(
       "{\n"
@@ -73,7 +84,18 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
   const int BYTES = 16;
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
+  // Use appropriate ds_load instruction based on byte size
+  if (BYTES == 4) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b32 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  } else if (BYTES == 8) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b64 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  }
 #else
   asm volatile(
       "{\n"
@@ -90,7 +112,18 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
+  // Use appropriate ds_load instruction based on byte size
+  if (BYTES == 4) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b32 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  } else if (BYTES == 8) {
+    asm volatile(
+        "{\n"
+        "   ds_load_b64 %0, %1\n"
+        "}\n" :: "v"(smem), "v"(glob_ptr));
+  }
 #else
   asm volatile(
       "{\n"
@@ -128,11 +161,19 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  asm volatile(
-      "ds_read_b128 %0, %1 offset:0\n"
-      "ds_read_b128 %2, %1 offset:16\n"
-      : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
-      : "v"(smem));
+  // Try using multiple ds_read_b32 instructions which are more widely supported
+  asm volatile(
+      "ds_read_b32 %0, %8 offset:0\n"
+      "ds_read_b32 %1, %8 offset:4\n"
+      "ds_read_b32 %2, %8 offset:8\n"
+      "ds_read_b32 %3, %8 offset:12\n"
+      "ds_read_b32 %4, %8 offset:16\n"
+      "ds_read_b32 %5, %8 offset:20\n"
+      "ds_read_b32 %6, %8 offset:24\n"
+      "ds_read_b32 %7, %8 offset:28\n"
+      : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]),
+        "=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7])
+      : "v"(smem));
 #else
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
@@ -145,7 +186,8 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
   asm volatile(
-      "ds_read_b64 %0, %2 offset:0\n"
+      "ds_read_b32 %0, %2 offset:0\n"
+      "ds_read_b32 %1, %2 offset:4\n"
       : "=v"(a[0]), "=v"(a[1])
       : "v"(smem));
 #else
@@ -161,11 +203,19 @@ __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
   uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  asm volatile(
-      "ds_read_b128 %0, %1 offset:0\n"
-      "ds_read_b128 %2, %1 offset:16\n"
-      : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
-      : "v"(smem));
+  // Try using multiple ds_read_b32 instructions which are more widely supported
+  asm volatile(
+      "ds_read_b32 %0, %8 offset:0\n"
+      "ds_read_b32 %1, %8 offset:4\n"
+      "ds_read_b32 %2, %8 offset:8\n"
+      "ds_read_b32 %3, %8 offset:12\n"
+      "ds_read_b32 %4, %8 offset:16\n"
+      "ds_read_b32 %5, %8 offset:20\n"
+      "ds_read_b32 %6, %8 offset:24\n"
+ "ds_read_b32 %7, %8 offset:28\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]), + "=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7]) + : "v"(smem)); #else asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h index 40f2dc30e1..a635a95c14 100644 --- a/torchao/csrc/cuda/sparse_marlin/mma.h +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -27,18 +27,26 @@ #include #endif +#ifdef USE_ROCM +#include +#include +#include // For some ROCm versions +// Some intrinsics might require the compiler to be in the right mode +// with the correct target architecture flags (-march=gfx942) +#endif + namespace torchao { // On CUDA earlier than 12.5, the ordered_metadata version of this instruction // is not supported. On later versions of CUDA the version without ordered // metadata results in the following warning: -// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction -// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially +// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction +// | 'mma' instead of modifier 'sp' as it is expected to have substantially // | reduced performance on some future architectures #if defined(USE_ROCM) - // HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the standard mma instruction - #define MMA_SP_INST "v_mfma_f32_16x16x16f16 " + // Correct MFMA instruction for AMD GPUs + #define MMA_SP_INST "v_mfma_f32_16x16x16_f16 " #elif defined(CUDA_VERSION) && CUDA_VERSION >= 12050 #define MMA_SP_INST \ "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " @@ -58,6 +66,23 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { + #ifdef USE_ROCM + // AMD GPUs use a different syntax for MFMA instructions + // The operands need to be listed individually, not in curly braces + asm volatile(MMA_SP_INST + "%0, %4, %8, %12\n" + : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]), + "v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3])); + + asm volatile(MMA_SP_INST + "%0, %4, %8, %12\n" + : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]), + "v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7])); + #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" @@ -72,7 +97,22 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), "r"(e[0])); + #endif } else { + #ifdef USE_ROCM + asm volatile(MMA_SP_INST + "%0, %4, %8, %12\n" + : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]), + "v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3])); + asm volatile(MMA_SP_INST + "%0, %4, %8, %12\n" + : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7]) + : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), + "v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]), + "v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7])); + #else asm volatile(MMA_SP_INST "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" @@ -87,6 +127,7 @@ __device__ inline 
@@ -87,6 +127,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]),
                  "r"(b[1]), "r"(b[3]), "r"(b[5]), "r"(b[7]),
                  "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), "r"(e[0]));
+  #endif
   }
 }
@@ -114,8 +155,8 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
   uint2 r;
 #ifdef USE_ROCM
   // AMD implementation
-  r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1);
-  r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3);
+  r.x = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c0, c1));
+  r.y = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c2, c3));
 #else
   // NVIDIA implementation
   asm("{\n\t"
@@ -177,8 +218,8 @@ __device__ inline FragB dequant_4bit(int q) {
   const __half2* MUL_ptr = reinterpret_cast<const __half2*>(&MUL);
   const __half2* ADD_ptr = reinterpret_cast<const __half2*>(&ADD);
 
-  frag_b[0] = __hsub(*lo_ptr, *SUB_ptr);
-  frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr);
+  frag_b[0] = __hsub2(*lo_ptr, *SUB_ptr);
+  frag_b[1] = __hfma2(*hi_ptr, *MUL_ptr, *ADD_ptr);
 #else
   // NVIDIA implementation
   frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
@@ -211,8 +252,8 @@ __device__ inline FragB dequant_8bit(int q) {
   __half2* hi_ptr = reinterpret_cast<__half2*>(&hi);
   const __half2* magic_num_ptr = reinterpret_cast<const __half2*>(&I8s_TO_F16s_MAGIC_NUM);
 
-  frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr);
-  frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr);
+  frag_b[0] = __hsub2(*lo_ptr, *magic_num_ptr);
+  frag_b[1] = __hsub2(*hi_ptr, *magic_num_ptr);
 #else
   // NVIDIA implementation
   frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
@@ -229,8 +270,8 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
 #ifdef USE_ROCM
   // AMD implementation
   __half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
-  frag_b[0] = __hmul(frag_b[0], s);
-  frag_b[1] = __hmul(frag_b[1], s);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
 #else
   // NVIDIA implementation
   half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
@@ -243,16 +284,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                     FragS& s0, float* c4, float* c5, float* c6, float* c7,
                                     FragS& s1) {
 #ifdef USE_ROCM
-  // AMD implementation
-  *c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x));
-  *c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y));
-  *c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x));
-  *c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y));
+  // AMD MI300X implementation
+  *c0 = *c0 * __half2float(s0[0].x);
+  *c1 = *c1 * __half2float(s0[0].y);
+  *c2 = *c2 * __half2float(s0[1].x);
+  *c3 = *c3 * __half2float(s0[1].y);
 
-  *c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x));
-  *c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y));
-  *c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x));
-  *c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y));
+  *c4 = *c4 * __half2float(s1[0].x);
+  *c5 = *c5 * __half2float(s1[0].y);
+  *c6 = *c6 * __half2float(s1[1].x);
+  *c7 = *c7 * __half2float(s1[1].y);
 #else
   // NVIDIA implementation
   *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
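
Note on the setup.py architecture gate above: the patch reads the device's gcnArchName and compares it against "gfx942". Below is a minimal standalone sketch of that probe, not part of the patch. It assumes a ROCm build of PyTorch (where torch.version.hip is set and get_device_properties() exposes gcnArchName), and the rocm_arch_supported helper name is hypothetical; it compares only the leading architecture token because ROCm commonly reports names like "gfx942:sramecc+:xnack-".

import torch

# Hypothetical helper (not part of the patch): mirrors the ROCm architecture
# gate used in setup.py's get_extensions().
def rocm_arch_supported(target: str = "gfx942") -> bool:
    # torch.version.hip is None on CUDA/CPU builds of PyTorch.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    # gcnArchName may include feature suffixes (e.g. "gfx942:sramecc+:xnack-"),
    # so compare only the leading architecture token.
    arch = torch.cuda.get_device_properties(0).gcnArchName
    return arch.split(":")[0] == target

if __name__ == "__main__":
    print("sparse_marlin ROCm target supported:", rocm_arch_supported())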