Commit 75f4787

Fix ROCm half-precision conversion in sparse Marlin MMA
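On ROCm, wrap the result of __builtin_amdgcn_cvt_pkrtz in __builtin_bit_cast(uint32_t, ...) so that each packed pair of half-precision values is reinterpreted as a uint32_t before it is stored in the uint2 returned by to_half4 in the sparse Marlin matrix multiply-accumulate (MMA) implementation. The comment quoting the ordered_metadata advisory is also updated to use plain ASCII quotation marks.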
1 parent 72c2642 commit 75f4787

1 file changed

+4
-4
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 4 additions & 4 deletions
@@ -27,8 +27,8 @@ namespace torchao {
 // On CUDA earlier than 12.5, the ordered_metadata version of this instruction
 // is not supported. On later versions of CUDA the version without ordered
 // metadata results in the following warning:
-// | Advisory: Modifier ‘.sp::ordered_metadata should be used on instruction
-// | mma instead of modifier ‘.sp’ as it is expected to have substantially
+// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction
+// | 'mma' instead of modifier 'sp' as it is expected to have substantially
 // | reduced performance on some future architectures
 
 #if defined(USE_ROCM)
@@ -143,8 +143,8 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
   uint2 r;
 #ifdef USE_ROCM
   // AMD implementation
-  r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1);
-  r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3);
+  r.x = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c0, c1));
+  r.y = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c2, c3));
 #else
   // NVIDIA implementation
   asm("{\n\t"
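
The ROCm change in to_half4 relies on reinterpreting the bits of a packed pair of halves rather than converting their numeric values. The following host-side sketch illustrates that idea only; it is not the code from mma.h. It assumes that __builtin_amdgcn_cvt_pkrtz yields a 32-bit vector of two half values, and it replaces that vector with a hypothetical PackedHalf2 struct holding raw 16-bit patterns so the example compiles without a GPU toolchain. The names PackedHalf2, Word2, pack_bits, and to_words are invented for illustration.

```cpp
#include <cstdint>

// Hypothetical stand-in for the two-element half vector assumed to be
// returned by __builtin_amdgcn_cvt_pkrtz. Each member stores the raw
// 16-bit pattern of one half-precision value.
struct PackedHalf2 {
  uint16_t first;
  uint16_t second;
};

// __builtin_bit_cast requires source and destination to have the same size.
static_assert(sizeof(PackedHalf2) == sizeof(uint32_t),
              "packed pair must occupy exactly 32 bits");

// Reinterpret the 32 bits of the pair as one integer. No numeric
// conversion takes place; the bit pattern is copied unchanged.
inline uint32_t pack_bits(PackedHalf2 v) {
  return __builtin_bit_cast(uint32_t, v);
}

// Hypothetical stand-in for uint2: two 32-bit words.
struct Word2 {
  uint32_t x;
  uint32_t y;
};

// Mirrors the shape of the ROCm branch of to_half4: each packed pair
// becomes one 32-bit word of the result.
inline Word2 to_words(PackedHalf2 a, PackedHalf2 b) {
  return Word2{pack_bits(a), pack_bits(b)};
}
```

For example, 0x3C00 is the half-precision bit pattern for 1.0, so pack_bits(PackedHalf2{0x3C00, 0x3C00}) produces the word 0x3C003C00. The explicit bit cast makes the reinterpretation visible in the source, whereas a plain assignment from a vector type to an integer depends on the compiler accepting an implicit conversion between unrelated types.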
