Skip to content

Commit 30bd924

Browse files
committed
Fix ROCm float multiplication in sparse Marlin MMA
Update AMD GPU implementation to use __builtin_amdgcn_fmul_f32 instead of __builtin_amdgcn_fmul_legacy for more accurate float multiplication in the scale_floats function.
1 parent a98a427 commit 30bd924

File tree

1 file changed

+9
-9
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+9
-9
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
272272
FragS& s0, float* c4, float* c5, float* c6,
273273
float* c7, FragS& s1) {
274274
#ifdef USE_ROCM
275-
// AMD implementation
276-
*c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x));
277-
*c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y));
278-
*c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x));
279-
*c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y));
275+
// AMD implementation - fixed
276+
*c0 = __builtin_amdgcn_fmul_f32(*c0, __half2float(s0[0].x));
277+
*c1 = __builtin_amdgcn_fmul_f32(*c1, __half2float(s0[0].y));
278+
*c2 = __builtin_amdgcn_fmul_f32(*c2, __half2float(s0[1].x));
279+
*c3 = __builtin_amdgcn_fmul_f32(*c3, __half2float(s0[1].y));
280280

281-
*c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x));
282-
*c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y));
283-
*c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x));
284-
*c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y));
281+
*c4 = __builtin_amdgcn_fmul_f32(*c4, __half2float(s1[0].x));
282+
*c5 = __builtin_amdgcn_fmul_f32(*c5, __half2float(s1[0].y));
283+
*c6 = __builtin_amdgcn_fmul_f32(*c6, __half2float(s1[1].x));
284+
*c7 = __builtin_amdgcn_fmul_f32(*c7, __half2float(s1[1].y));
285285
#else
286286
// NVIDIA implementation
287287
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));

0 commit comments

Comments
 (0)