@@ -53,6 +53,22 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
53
53
54
54
float * c = reinterpret_cast <float *>(&frag_c);
55
55
if (psel == 0 ) {
56
+ #ifdef USE_ROCM
57
+ asm volatile (MMA_SP_INST
58
+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
59
+ " {%12,%13,%14,%15}, %16, 0x0;\n "
60
+ : " =v" (c[0 ]), " =v" (c[1 ]), " =v" (c[2 ]), " =v" (c[3 ])
61
+ : " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[0 ]),
62
+ " r" (b[2 ]), " r" (b[4 ]), " r" (b[6 ]), " v" (c[0 ]), " v" (c[1 ]),
63
+ " v" (c[2 ]), " v" (c[3 ]), " r" (e[0 ]));
64
+ asm volatile (MMA_SP_INST
65
+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
66
+ " {%12,%13,%14,%15}, %16, 0x0;\n "
67
+ : " =v" (c[4 ]), " =v" (c[5 ]), " =v" (c[6 ]), " =v" (c[7 ])
68
+ : " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[1 ]),
69
+ " r" (b[3 ]), " r" (b[5 ]), " r" (b[7 ]), " v" (c[4 ]), " v" (c[5 ]),
70
+ " v" (c[6 ]), " v" (c[7 ]), " r" (e[0 ]));
71
+ #else
56
72
asm volatile (MMA_SP_INST
57
73
" {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
58
74
" {%12,%13,%14,%15}, %16, 0x0;\n "
@@ -67,7 +83,24 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
67
83
: " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[1 ]),
68
84
" r" (b[3 ]), " r" (b[5 ]), " r" (b[7 ]), " f" (c[4 ]), " f" (c[5 ]),
69
85
" f" (c[6 ]), " f" (c[7 ]), " r" (e[0 ]));
86
+ #endif
70
87
} else {
88
+ #ifdef USE_ROCM
89
+ asm volatile (MMA_SP_INST
90
+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
91
+ " {%12,%13,%14,%15}, %16, 0x1;\n "
92
+ : " =v" (c[0 ]), " =v" (c[1 ]), " =v" (c[2 ]), " =v" (c[3 ])
93
+ : " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[0 ]),
94
+ " r" (b[2 ]), " r" (b[4 ]), " r" (b[6 ]), " v" (c[0 ]), " v" (c[1 ]),
95
+ " v" (c[2 ]), " v" (c[3 ]), " r" (e[0 ]));
96
+ asm volatile (MMA_SP_INST
97
+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
98
+ " {%12,%13,%14,%15}, %16, 0x1;\n "
99
+ : " =v" (c[4 ]), " =v" (c[5 ]), " =v" (c[6 ]), " =v" (c[7 ])
100
+ : " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[1 ]),
101
+ " r" (b[3 ]), " r" (b[5 ]), " r" (b[7 ]), " v" (c[4 ]), " v" (c[5 ]),
102
+ " v" (c[6 ]), " v" (c[7 ]), " r" (e[0 ]));
103
+ #else
71
104
asm volatile (MMA_SP_INST
72
105
" {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
73
106
" {%12,%13,%14,%15}, %16, 0x1;\n "
@@ -82,6 +115,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
82
115
: " r" (a0[0 ]), " r" (a1[0 ]), " r" (a0[1 ]), " r" (a1[1 ]), " r" (b[1 ]),
83
116
" r" (b[3 ]), " r" (b[5 ]), " r" (b[7 ]), " f" (c[4 ]), " f" (c[5 ]),
84
117
" f" (c[6 ]), " f" (c[7 ]), " r" (e[0 ]));
118
+ #endif
85
119
}
86
120
}
87
121
0 commit comments