Skip to content

Commit cf79039

Browse files
committed
Optimize half-precision operations in sparse Marlin MMA
Replace the scalar half-precision intrinsics __hsub and __hfma in dequant_4bit with their packed counterparts __hsub2 and __hfma2, which operate on __half2 values, to improve performance and precision in sparse matrix multiply-accumulate (MMA) computations.
1 parent 75f4787 commit cf79039

File tree

1 file changed

+2
-2
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+2
-2
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,8 +206,8 @@ __device__ inline FragB dequant_4bit(int q) {
206206
const __half2* MUL_ptr = reinterpret_cast<const __half2*>(&MUL);
207207
const __half2* ADD_ptr = reinterpret_cast<const __half2*>(&ADD);
208208

209-
frag_b[0] = __hsub(*lo_ptr, *SUB_ptr);
210-
frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr);
209+
frag_b[0] = __hsub2(*lo_ptr, *SUB_ptr);
210+
frag_b[1] = __hfma2(*hi_ptr, *MUL_ptr, *ADD_ptr);
211211
#else
212212
// NVIDIA implementation
213213
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),

0 commit comments

Comments
 (0)