Skip to content

Commit a98a427

Browse files
committed
Optimize ROCm half-precision operations in sparse Marlin MMA
Update the AMD GPU (ROCm) implementation to use the packed __hsub2 and __hmul2 intrinsics in place of the scalar __hsub and __hmul, for improved performance and precision in half-precision sparse matrix multiply-accumulate computations.
1 parent cf79039 commit a98a427

File tree

1 file changed

+4
-4
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+4
-4
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -240,8 +240,8 @@ __device__ inline FragB dequant_8bit(int q) {
240240
__half2* hi_ptr = reinterpret_cast<__half2*>(&hi);
241241
const __half2* magic_num_ptr = reinterpret_cast<const __half2*>(&I8s_TO_F16s_MAGIC_NUM);
242242

243-
frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr);
244-
frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr);
243+
frag_b[0] = __hsub2(*lo_ptr, *magic_num_ptr);
244+
frag_b[1] = __hsub2(*hi_ptr, *magic_num_ptr);
245245
#else
246246
// NVIDIA implementation
247247
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
@@ -258,8 +258,8 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
258258
#ifdef USE_ROCM
259259
// AMD implementation
260260
__half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
261-
frag_b[0] = __hmul(frag_b[0], s);
262-
frag_b[1] = __hmul(frag_b[1], s);
261+
frag_b[0] = __hmul2(frag_b[0], s);
262+
frag_b[1] = __hmul2(frag_b[1], s);
263263
#else
264264
// NVIDIA implementation
265265
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);

0 commit comments

Comments
 (0)