Skip to content

Commit 72c2642

Browse files
committed
Add ROCm-specific inline assembly for sparse Marlin MMA operations
Add conditional compilation for ROCm platforms in the sparse Marlin matrix multiply-accumulate (MMA) function. This provides a proper inline assembly implementation for both CUDA and ROCm environments, using platform-specific register constraints and instruction handling.
1 parent 94d1fb4 commit 72c2642

File tree

1 file changed

+34
-0
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+34
-0
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,22 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
5353

5454
float* c = reinterpret_cast<float*>(&frag_c);
5555
if (psel == 0) {
56+
#ifdef USE_ROCM
57+
asm volatile(MMA_SP_INST
58+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
59+
"{%12,%13,%14,%15}, %16, 0x0;\n"
60+
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
61+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
62+
"r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]),
63+
"v"(c[2]), "v"(c[3]), "r"(e[0]));
64+
asm volatile(MMA_SP_INST
65+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
66+
"{%12,%13,%14,%15}, %16, 0x0;\n"
67+
: "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
68+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
69+
"r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]),
70+
"v"(c[6]), "v"(c[7]), "r"(e[0]));
71+
#else
5672
asm volatile(MMA_SP_INST
5773
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
5874
"{%12,%13,%14,%15}, %16, 0x0;\n"
@@ -67,7 +83,24 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
6783
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
6884
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
6985
"f"(c[6]), "f"(c[7]), "r"(e[0]));
86+
#endif
7087
} else {
88+
#ifdef USE_ROCM
89+
asm volatile(MMA_SP_INST
90+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
91+
"{%12,%13,%14,%15}, %16, 0x1;\n"
92+
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
93+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
94+
"r"(b[2]), "r"(b[4]), "r"(b[6]), "v"(c[0]), "v"(c[1]),
95+
"v"(c[2]), "v"(c[3]), "r"(e[0]));
96+
asm volatile(MMA_SP_INST
97+
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
98+
"{%12,%13,%14,%15}, %16, 0x1;\n"
99+
: "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
100+
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
101+
"r"(b[3]), "r"(b[5]), "r"(b[7]), "v"(c[4]), "v"(c[5]),
102+
"v"(c[6]), "v"(c[7]), "r"(e[0]));
103+
#else
71104
asm volatile(MMA_SP_INST
72105
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
73106
"{%12,%13,%14,%15}, %16, 0x1;\n"
@@ -82,6 +115,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
82115
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
83116
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
84117
"f"(c[6]), "f"(c[7]), "r"(e[0]));
118+
#endif
85119
}
86120
}
87121

0 commit comments

Comments
 (0)