@@ -33,13 +33,41 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define USE_MERGE_MMA
 #endif
 
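+// Load four consecutive bf16 vectors from in into in0 using two paired loads.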
+FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in)
+{
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
+}
+
 FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
 {
   vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
 
   __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
 }
 
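+// Load one bf16 vector from each of two inputs and accumulate the bf16 outer
+// products with inp into two separate accumulators.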
+FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
+{
+  vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
+  vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+}
+
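+// As above, but with four input pointers and four accumulators.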
+FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
+{
+  vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
+  vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
+  vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
+  vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
+}
+
 FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
 {
   vec_bf16 in0[2];
@@ -50,13 +78,123 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
   __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
 }
 
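+// Load two bf16 vectors from each of two inputs and accumulate against two
+// inp vectors into two accumulators.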
+FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_bf16 in01[2], in11[2];
+
+  vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
+  vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+}
+
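+// Load two bf16 vectors from each of four inputs and accumulate against two
+// inp vectors into four accumulators.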
+FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
+{
+  vec_bf16 in01[2], in11[2], in21[2], in31[2];
+
+  vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
+  vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
+  vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
+  vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
+}
+
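+// Load four bf16 vectors from one input and accumulate against four inp
+// vectors into a single accumulator.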
+FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_bf16 in0[4];
+
+  vec_load_pair2(in0, in);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]);
+}
+
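+// Load four bf16 vectors from each of two inputs and accumulate against four
+// inp vectors into two accumulators.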
+FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_bf16 in01[4], in11[4];
+
+  vec_load_pair2(in01, in0);
+  vec_load_pair2(in11, in1);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
+}
+
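+// Load four bf16 vectors from each of four inputs and accumulate against four
+// inp vectors into four accumulators.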
+FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
+{
+  vec_bf16 in01[4], in11[4], in21[4], in31[4];
+
+  vec_load_pair2(in01, in0);
+  vec_load_pair2(in11, in1);
+  vec_load_pair2(in21, in2);
+  vec_load_pair2(in31, in3);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
+}
+
 FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
 {
   vec_bf16 in0 = vec_loadN(in, n);
 
   __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
 }
 
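+// Partial-length variant: load n elements from each of two inputs and
+// accumulate into two accumulators.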
+FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
+  vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+}
+
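+// Partial-length variant: load n elements from each of four inputs and
+// accumulate into four accumulators.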
+FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
+  vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
+  vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
+  vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
+
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
+  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
+}
+
 FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
 {
   vec_bf16 in00 = vec_mergeh(in0, in0);